From fbc51d31bb4929f6c6b16f2e15ada46fb8ca5db0 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Wed, 17 Sep 2025 13:46:51 +0000 Subject: [PATCH 01/26] First algorithm changed for modernization of configs --- .../include/rocprim/device/config_types.hpp | 312 +- .../config/device_radix_sort_onesweep.hpp | 11029 +++++++--------- .../device/detail/device_config_helper.hpp | 17 +- .../rocprim/device/device_radix_sort.hpp | 95 +- .../device/device_radix_sort_config.hpp | 51 +- 5 files changed, 5016 insertions(+), 6488 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index e20a65b3e52..2643cd571a7 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -26,6 +26,8 @@ #include #include #include +#include +#include #include #include @@ -181,6 +183,123 @@ enum class target_arch : unsigned int }; #endif // DOXYGEN_SHOULD_SKIP_THIS +enum class rep +{ + amdgcn, + spirv, +}; + +enum class gen +{ + unknown, + gcn3, + gcn5, + cdna1, + cdna2, + cdna3, + cdna4, + rdna2, + rdna3, + rdna4, +}; + +enum class gpu +{ + generic, + v620, + rx6900, + rx7900, + rx9070, + mi50, + mi100, + mi210, + mi300x, + mi300a, + mi308x, + mi325x, + mi350x +}; + +constexpr gen gen_from_target_arch(target_arch i) +{ + switch(i) + { + case target_arch::gfx803: return gen::gcn3; + case target_arch::gfx900: + case target_arch::gfx906: return gen::gcn5; + case target_arch::gfx908: return gen::cdna1; + case target_arch::gfx90a: return gen::cdna2; + case target_arch::gfx942: return gen::cdna3; + case target_arch::gfx950: return gen::cdna4; + case target_arch::gfx1030: return gen::rdna2; + case target_arch::gfx1100: + case target_arch::gfx1102: return gen::rdna3; + case target_arch::gfx1200: + case target_arch::gfx1201: return gen::rdna4; + default: return gen::unknown; + } +} + +constexpr std::tuple target_gpu_names[] = { + std::make_tuple("MI350X", gpu::mi350x), + std::make_tuple("MI325X", gpu::mi325x), + std::make_tuple("MI308X", gpu::mi308x), + std::make_tuple("MI300A", gpu::mi300a), + std::make_tuple("MI300X", gpu::mi300x), + std::make_tuple("MI210", gpu::mi210), + std::make_tuple("MI100", gpu::mi100), + std::make_tuple("RX 9070", gpu::rx9070), + std::make_tuple("V620", gpu::v620), + std::make_tuple("RX 7900", gpu::rx7900), + std::make_tuple("RX 6900", gpu::rx6900), +}; + +template +struct comp_target +{ + static constexpr gen g = g_; + static constexpr target_arch i = i_; + static constexpr gpu s = s_; + static constexpr rep r = r_; +}; + +struct target +{ + gen g; + target_arch i; + gpu s; + rep r; + + constexpr target(target_arch i, gpu s = gpu::generic, rep r = rep::amdgcn) + : g(gen_from_target_arch(i)), i(i), s(s), r(r){}; + + constexpr target(gen g = gen::unknown, + target_arch i = target_arch::unknown, + gpu s = gpu::generic, + rep r = rep::amdgcn) + : g(g), i(i), s(s), r(r){}; + + template + constexpr target(CompTarget) + : g(CompTarget::g), i(CompTarget::i), s(CompTarget::s), r(CompTarget::r) + {} + + constexpr bool operator==(target other) const + { + return g == other.g && i == other.i && s == other.s && r == other.r; + } +}; + +template +struct comp_targets +{ + template + static constexpr void for_each(F f) + { + (f(Ts{}), ...); + } +}; + /** * \brief Checks if the first `n` characters of `rhs` are equal to `lhs` * @@ -363,20 +482,6 @@ struct radix_sort_config_selector = Config::template architecture_config::params.block_size; }; -template -struct radix_sort_onesweep_histogram_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.histogram.block_size; -}; - -template -struct radix_sort_onesweep_sort_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.sort.block_size; -}; - template struct segmented_radix_sort_warp_sort_small_config_selector { @@ -462,6 +567,147 @@ auto make_launch_plan(target_arch arch, Kernel kernel) -> launch_plan return {tuned_kernel.value(), kernel}; } +template +constexpr target most_common_config(target target_current) +{ + // Takes unknown as default. + target ret{}; + Targets::for_each( + [&](auto t) + { + // Skip unknown target for picking. + if(!(target{} == t)) + { + constexpr target_arch Arch = t.i; + constexpr gpu GPU = t.s; + constexpr gen Gen = t.g; + + // Update `ret` if the candidate `t` matches more specifically than the current `ret`. + // Priority order: prefer exact GPU match first; otherwise allow an Arch match (if GPU differs); + // finally allow a Gen match (if both Arch and GPU differ). This ensures we progressively + // refine the fallback from generic -> generation -> arch -> exact GPU. + if((GPU == target_current.s) + || (Arch == target_current.i + && (target_current.s != ret.s || ret.s == gpu::generic)) + || (Gen == target_current.g + && ((target_current.s != ret.s || ret.s == gpu::generic) + && (target_current.i != ret.i || ret.i == target_arch::unknown)))) + { + ret = target{t}; + } + } + }); + + return ret; +} + +template +constexpr typename Selector::param_type default_select_config(target t) +{ + using Targets = typename Selector::targets; + using Params = typename Selector::param_type; + + const target target_config = most_common_config(t); + + Params params{}; + + Targets::for_each( + [&](auto candidate) + { + if(target{candidate} == target_config) + { + params = Selector{candidate}.params; + } + }); + + return params; +} + +template +constexpr auto get_config(Config config, target t) +{ + if constexpr(std::is_same_v) + { + return default_select_config(t); + } + else + { + return config; + } +}; + +template +struct target_config2 +{ + constexpr static auto params = get_config(Config{}, target{Target{}}); + constexpr static auto wavefront = arch_wavefront_size(Target::i); + constexpr static auto arch = Target::i; +}; + +template +struct default_config_static_selector +{ + static constexpr auto block_size + = target_config2::params.kernel_config.block_size; +}; + +// trampoline_kernel that is fully specialized at compile-time for a single GPU architecture. +// By instantiating this template once per supported `target_arch`, the correct tuned config +// will be derived from the template. +template + class LaunchSelector> +ROCPRIM_KERNEL __launch_bounds__((LaunchSelector::block_size)) +void trampoline_kernel(Kernel kernel) +{ + using ArchConfig = target_config2; + +#if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 + if constexpr(Target::i == device_target_arch()) +#endif + { + kernel(ArchConfig{}); + } +#if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 + else + { + __builtin_unreachable(); + } +#endif +} + +template class LaunchSelector = default_config_selector> +auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan +{ + using Targets = typename ConfigSelector::targets; + + std::optional tuned_kernel = std::nullopt; + + const target target_config = most_common_config(target_current); + + // The target config is always in Targets. + Targets::for_each( + [&](auto t) + { + if(target{t} == target_config) + { + tuned_kernel = trampoline_kernel; + } + }); + + return {tuned_kernel.value(), kernel}; +} + // Host-side helper running at run-time, picking the trampoline_kernel whose template // argument `Arch` matches the actual GPU we are executing on. template class LaunchSelector = default_config_static_selector, + class Kernel> +hipError_t execute_launch_plan( + target t, Kernel kernel, dim3 grid_size, dim3 block_size, size_t shmem, hipStream_t stream) +{ + const auto launch_plan + = make_launch_plan(t, kernel); + launch_plan.launch(grid_size, block_size, shmem, stream); + return hipGetLastError(); +} + #ifdef ROCPRIM_EXPERIMENTAL_SPIRV template #else @@ -602,6 +861,31 @@ inline hipError_t host_target_arch(const hipStream_t stream, target_arch& arch) return get_device_arch(device_id, arch); } +constexpr gpu get_target_gpu_from_name(std::string_view name) +{ + for(const auto& each : target_gpu_names) + { + if(name.find(std::get<0>(each)) != name.npos) + { + return std::get<1>(each); + } + } + return gpu::generic; +} + +inline hipError_t host_target_gpu(const hipStream_t stream, gpu& gpu) +{ + int device_id; + ROCPRIM_RETURN_ON_ERROR(get_device_from_stream(stream, device_id)); + + hipDeviceProp_t prop; + ROCPRIM_RETURN_ON_ERROR(hipGetDeviceProperties(&prop, device_id)); + + gpu = get_target_gpu_from_name(prop.name); + + return hipSuccess; +} + } // end namespace detail /// \brief Returns a number of threads in a hardware warp for the actual device. diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp index 9e2fe8c795c..c761e93bc8d 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp @@ -40,6404 +40,4637 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_radix_sort_onesweep_config - : radix_sort_onesweep_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 1>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 18>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<256, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 6, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 4>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<512, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 8>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<512, 22>, - 6, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : radix_sort_onesweep_config, - kernel_config<1024, 6>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : radix_sort_onesweep_config, - kernel_config<1024, 16>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 8, - block_radix_rank_algorithm::match> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 22>, - 4, - block_radix_rank_algorithm::basic> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx950), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : radix_sort_onesweep_config, - kernel_config<1024, 12>, - 8, - block_radix_rank_algorithm::match> -{}; +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 16}, + kernel_config_params{128, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 16}, + kernel_config_params{128, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 16}, + kernel_config_params{128, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 16}, + kernel_config_params{128, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 16}, + kernel_config_params{128, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{128, 22}, + kernel_config_params{128, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 6}, + kernel_config_params{256, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 6}, + kernel_config_params{256, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 6}, + kernel_config_params{256, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 4, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 6}, + kernel_config_params{256, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 7, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, + 6, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 32}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 6, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 32}, + kernel_config_params{256, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 6, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 22}, + 8, + block_radix_rank_algorithm::match + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 4, + block_radix_rank_algorithm::basic + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, + 8, + block_radix_rank_algorithm::match + }; + } + // Default case if none of the conditions match + return radix_sort_onesweep_config_params_base(); +} + +template +constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< + std::is_same>::value, + radix_sort_onesweep_config_params> +{ + return radix_sort_onesweep_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using radix_sort_onesweep_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 585e42fb6eb..9c44d87ca78 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -229,16 +229,17 @@ struct reduce_config_tag // Calculate kernel configurations, such that it will not exceed shared memory maximum template -struct radix_sort_onesweep_config_base +constexpr radix_sort_onesweep_config_params radix_sort_onesweep_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr unsigned int block_size = merge_sort_block_size(item_scale) * 4; - static constexpr unsigned int block_size = merge_sort_block_size(item_scale) * 4; - using type = radix_sort_onesweep_config< - kernel_config<256, 12>, - kernel_config, - 4>; -}; + return radix_sort_onesweep_config_params{ + kernel_config_params{256, 12}, + kernel_config_params{block_size, ::rocprim::max(1u, 65000u / block_size / item_scale)}, + 4 + }; +} struct reduce_config_params { diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp index db3a92d33ad..a4d993087c0 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp @@ -36,6 +36,7 @@ #include "../types.hpp" #include "../type_traits.hpp" +#include "config_types.hpp" #include "detail/config/device_radix_sort_onesweep.hpp" #include "detail/device_radix_sort.hpp" #include "device_transform.hpp" @@ -97,16 +98,18 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, { using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_radix_sort_onesweep_config; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const radix_sort_onesweep_config_params params - = dispatch_target_arch(target_arch); + using Selector = radix_sort_onesweep_config_selector; + + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const radix_sort_onesweep_config_params params = get_config(Config{}, get_target); const unsigned int items_per_block = params.histogram.block_size * params.histogram.items_per_thread; @@ -147,15 +150,15 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, begin_bit, end_bit); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( - target_arch, - onesweep_histograms_kernel, - dim3(blocks), - dim3(params.histogram.block_size), - 0, - stream)); + + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan( + get_target, + onesweep_histograms_kernel, + dim3(blocks), + dim3(params.histogram.block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("compute_global_digit_histograms", size, start); // Scan each histogram separately to get the final offsets. @@ -170,15 +173,16 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, onesweep_scan_histograms( global_digit_offsets); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( - target_arch, - onesweep_scan_histograms_kernel, - dim3(digit_places), // One block for every digit place. - dim3(params.histogram.block_size), - 0, - stream)); + + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan( + get_target, + onesweep_scan_histograms_kernel, + dim3(digit_places), // One block for every digit place. + dim3(params.histogram.block_size), + 0, + stream)); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("scan_global_digit_histograms", bins, start); return hipSuccess; } @@ -214,12 +218,18 @@ hipError_t radix_sort_onesweep_iteration( { using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_radix_sort_onesweep_config; - detail::target_arch target_arch; + using Selector = radix_sort_onesweep_config_selector; + + target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const radix_sort_onesweep_config_params params - = dispatch_target_arch(target_arch); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const radix_sort_onesweep_config_params params = get_config(Config{}, get_target); const unsigned int items_per_block = params.sort.block_size * params.sort.items_per_thread; const unsigned int current_radix_bits @@ -278,7 +288,8 @@ hipError_t radix_sort_onesweep_iteration( { auto onesweep_iteration_kernel = [=](auto arch_config) { - static constexpr auto params = decltype(arch_config)::params; + static constexpr radix_sort_onesweep_config_params params + = decltype(arch_config)::params; onesweep_iteration( - target_arch, + return execute_launch_plan( + get_target, onesweep_iteration_kernel, dim3(blocks), dim3(params.sort.block_size), @@ -373,6 +384,8 @@ hipError_t radix_sort_onesweep_impl( using value_type = typename std::iterator_traits::value_type; using offset_type = offset_type_t; + using Selector = radix_sort_onesweep_config_selector; + bool use_atomic_block_id; ROCPRIM_RETURN_ON_ERROR(check_if_using_atomic_block_id(stream, use_atomic_block_id)); const auto use_atomic_block_id_variant @@ -383,13 +396,17 @@ hipError_t radix_sort_onesweep_impl( [&](auto use_atomic_block_id) { using ordered_bid_type = block_id_wrapper; - using config = wrapped_radix_sort_onesweep_config; - detail::target_arch target_arch; + target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + const radix_sort_onesweep_config_params params - = dispatch_target_arch(target_arch); + = get_config(Config{}, get_target); const unsigned int sort_items_per_block = params.sort.block_size * params.sort.items_per_thread; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp index 52ac827660a..ed69e8479b3 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp @@ -23,6 +23,7 @@ #include "config_types.hpp" #include "detail/config/device_radix_sort_block_sort.hpp" +#include "detail/config/device_radix_sort_onesweep.hpp" #include "detail/device_config_helper.hpp" /// \addtogroup primitivesmodule_deviceconfigs @@ -65,41 +66,33 @@ struct radix_sort_config namespace detail { -// sub-algorithm onesweep: -template -struct wrapped_radix_sort_onesweep_config +template +struct radix_sort_onesweep_config_selector { - template - struct architecture_config - { - static constexpr radix_sort_onesweep_config_params params = RadixSortOnesweepConfig(); - }; + using targets = radix_sort_onesweep_targets; + using param_type = radix_sort_onesweep_config_params; + + param_type params; + + template + constexpr radix_sort_onesweep_config_selector(Target) + : params(radix_sort_onesweep_config_picker()) + {} }; -template -struct wrapped_radix_sort_onesweep_config +template +struct radix_sort_onesweep_histogram_config_static_selector { - template - struct architecture_config - { - static constexpr radix_sort_onesweep_config_params params - = default_radix_sort_onesweep_config(Arch), Key, Value>(); - }; + static constexpr auto block_size = target_config2::params + .radix_sort_onesweep_config_params::histogram.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr radix_sort_onesweep_config_params - wrapped_radix_sort_onesweep_config::architecture_config< - Arch>::params; - -template -template -constexpr radix_sort_onesweep_config_params - wrapped_radix_sort_onesweep_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS +template +struct radix_sort_onesweep_sort_config_static_selector +{ + static constexpr auto block_size = target_config2::params + .radix_sort_onesweep_config_params::sort.block_size; +}; // Sub-algorithm block_sort: template From a1439f62e6aef629560b92095656aa0ba4057c3c Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Wed, 5 Nov 2025 08:01:24 +0000 Subject: [PATCH 02/26] Resolve "Update configs for new config system part 5" --- .../include/rocprim/device/config_types.hpp | 18 +- .../device/detail/config/device_scan.hpp | 1766 ++- .../detail/config/device_scan_by_key.hpp | 9105 +++++-------- .../device/detail/config/device_search_n.hpp | 1193 +- .../config/device_segmented_radix_sort.hpp | 11107 ++++++---------- .../device/detail/device_config_helper.hpp | 92 +- .../rocprim/device/detail/device_search_n.hpp | 57 +- .../include/rocprim/device/device_scan.hpp | 35 +- .../rocprim/device/device_scan_by_key.hpp | 22 +- .../device/device_scan_by_key_config.hpp | 40 +- .../rocprim/device/device_scan_config.hpp | 40 +- .../rocprim/device/device_search_n_config.hpp | 40 +- .../device/device_segmented_radix_sort.hpp | 57 +- .../device_segmented_radix_sort_config.hpp | 55 +- .../rocprim/device/device_segmented_scan.hpp | 28 +- .../rocprim/test/rocprim/test_device_scan.cpp | 55 +- .../test/rocprim/test_device_search_n.cpp | 75 +- .../test/rocprim/test_linking_new_scan.hpp | 30 +- 18 files changed, 9495 insertions(+), 14320 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 2643cd571a7..756be84f7b0 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -209,6 +209,7 @@ enum class gpu v620, rx6900, rx7900, + rx9060, rx9070, mi50, mi100, @@ -248,6 +249,7 @@ constexpr std::tuple target_gpu_names[] = { std::make_tuple("MI300X", gpu::mi300x), std::make_tuple("MI210", gpu::mi210), std::make_tuple("MI100", gpu::mi100), + std::make_tuple("RX 9060", gpu::rx9060), std::make_tuple("RX 9070", gpu::rx9070), std::make_tuple("V620", gpu::v620), std::make_tuple("RX 7900", gpu::rx7900), @@ -482,20 +484,6 @@ struct radix_sort_config_selector = Config::template architecture_config::params.block_size; }; -template -struct segmented_radix_sort_warp_sort_small_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.warp_sort_config.block_size_small; -}; - -template -struct segmented_radix_sort_warp_sort_meduim_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.warp_sort_config.block_size_medium; -}; - template struct target_config { @@ -624,7 +612,7 @@ constexpr typename Selector::param_type default_select_config(target t) } template -constexpr auto get_config(Config config, target t) +constexpr typename Selector::param_type get_config(Config config, target t) { if constexpr(std::is_same_v) { diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan.hpp index 457e83e157f..5ff7018c330 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan.hpp @@ -40,986 +40,792 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_scan_config : default_scan_config_base::type -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<64, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<256, - 2, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<64, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<64, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<64, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<128, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = double -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = float -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = rocprim::half -template -struct default_scan_config(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int64_t -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on value_type = int -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = short -template -struct default_scan_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = int8_t -template -struct default_scan_config(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : scan_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_scan_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {64, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {64, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {256, 2}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {64, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {64, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {64, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {64, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {128, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return scan_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_config_params_base(); +} + +template +constexpr auto scan_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_config_params> +{ + return scan_config_picker, + value_type>(); +} + +// All the existing configs should be auto generated +using scan_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp index 8436e810f2c..800e48e6862 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_scan_by_key.hpp @@ -40,5437 +40,3680 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_scan_by_key_config : default_scan_by_key_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 2, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 13, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<64, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 17, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 17, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 17, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 17, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 4, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 13, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 5, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 22, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 3, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<64, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<64, - 15, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<64, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<128, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 19, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 23, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 21, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = double, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = float, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<128, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 6, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<128, - 10, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<128, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 12, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 7, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 11, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 9, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 8, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 14, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : scan_by_key_config<128, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : scan_by_key_config<256, - 20, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : scan_by_key_config<256, - 13, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : scan_by_key_config<256, - 16, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : scan_by_key_config<256, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::reduce_then_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : scan_by_key_config<256, - 18, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_scan_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> - : scan_by_key_config<128, - 24, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 2}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 13}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {64, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {128, 17}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 17}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 17}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 17}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 4}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 13}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 15}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {64, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 5}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {64, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 22}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 3}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {64, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {128, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 19}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 23}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 21}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {128, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 6}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {128, 10}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 12}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 7}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 11}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 9}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 14}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {128, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {256, 20}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return scan_by_key_config_params{ + {256, 13}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return scan_by_key_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return scan_by_key_config_params{ + {256, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::reduce_then_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return scan_by_key_config_params{ + {256, 18}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return scan_by_key_config_params{ + {128, 24}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return scan_by_key_config_params_base(); +} + +template +constexpr auto scan_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + scan_by_key_config_params> +{ + return scan_by_key_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using scan_by_key_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_search_n.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_search_n.hpp index 30a3be07d2f..5906b95f35a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_search_n.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_search_n.hpp @@ -40,618 +40,587 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_search_n_config : default_search_n_config_base::type -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<512, 4, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 4, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<1024, 8, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<1024, 1, 8> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<512, 4, 16> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<1024, 8, 16> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<1024, 16, 16> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<256, 4, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 4, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<1024, 16, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<256, 1, 8> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<512, 4, 12> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<1024, 16, 16> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<1024, 16, 16> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<128, 2, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<256, 1, 4> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<128, 2, 4> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 2, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<1024, 2, 4> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<256, 4, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<512, 1, 12> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<1024, 2, 12> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 2, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<1024, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<512, 4, 8> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<128, 2, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<128, 4, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<128, 1, 4> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<128, 2, 8> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 2, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<128, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<128, 4, 8> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<1024, 2, 4> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<256, 2, 4> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<256, 4, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<512, 1, 12> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<1024, 2, 16> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<256, 2, 8> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<1024, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<512, 4, 8> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<64, 2, 4> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<128, 4, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx1201), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<128, 8, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<64, 1, 8> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<64, 2, 12> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<64, 4, 4> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<64, 8, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx1201), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<64, 16, 4> -{}; - -// Based on data_type = double -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : search_n_config<256, 2, 8> -{}; - -// Based on data_type = float -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : search_n_config<128, 4, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_search_n_config(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> - : search_n_config<256, 4, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : search_n_config<128, 2, 12> -{}; - -// Based on data_type = int64_t -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : search_n_config<256, 2, 12> -{}; - -// Based on data_type = int -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : search_n_config<128, 4, 16> -{}; - -// Based on data_type = short -template -struct default_search_n_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : search_n_config<256, 4, 8> -{}; - -// Based on data_type = int8_t -template -struct default_search_n_config(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> - : search_n_config<1024, 8, 8> -{}; +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {512, 4}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {1024, 8}, + 16 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {1024, 1}, + 8 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {512, 4}, + 16 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {1024, 8}, + 16 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {1024, 16}, + 16 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {1024, 16}, + 16 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {256, 1}, + 8 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {512, 4}, + 12 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {1024, 16}, + 16 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {1024, 16}, + 16 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {128, 2}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {256, 1}, + 4 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {128, 2}, + 4 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {1024, 2}, + 4 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {256, 4}, + 12 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {512, 1}, + 12 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {1024, 2}, + 12 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {1024, 4}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {512, 4}, + 8 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {128, 2}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {128, 4}, + 8 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {128, 1}, + 4 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {128, 2}, + 8 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {128, 4}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {128, 4}, + 8 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {64, 2}, + 4 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {128, 4}, + 12 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {128, 8}, + 12 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {64, 1}, + 8 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {64, 2}, + 12 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {64, 4}, + 4 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {64, 8}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {64, 16}, + 4 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {256, 2}, + 8 + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {128, 4}, + 12 + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return search_n_config_params{ + {256, 4}, + 16 + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return search_n_config_params{ + {128, 2}, + 12 + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return search_n_config_params{ + {256, 2}, + 12 + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return search_n_config_params{ + {128, 4}, + 16 + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return search_n_config_params{ + {256, 4}, + 8 + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return search_n_config_params{ + {1024, 8}, + 8 + }; + } + // Default case if none of the conditions match + return search_n_config_params_base(); +} + +template +constexpr auto search_n_config_picker() -> std::enable_if_t< + std::is_same>::value, + search_n_config_params> +{ + return search_n_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using search_n_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp index 628fab9a2fa..43410d17367 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_radix_sort.hpp @@ -40,6966 +40,4153 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<6>::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<128, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<128, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<128, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<128, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 8, 256, 5, 16, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<16, 8, 256, 5, 32, 16, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 7, - kernel_config<256, 17>, - typename std::conditional<1, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 4>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 16>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_segmented_radix_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : segmented_radix_sort_config< - 8, - kernel_config<256, 8>, - typename std::conditional<1, - WarpSortConfig<8, 4, 256, 64, 16, 8, 256>, - DisabledWarpSortConfig>::type, - 1> -{}; +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{128, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{128, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{128, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{128, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 16, 8, 256, 5, 32, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 8, 256, 5, 16, 16, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 7, + kernel_config_params{256, 17}, + warp_sort_config_params{1, 32, 4, 256, 3000, 32, 4, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 16}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 4}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return segmented_radix_sort_config_params{ + 8, + kernel_config_params{256, 8}, + warp_sort_config_params{1, 8, 4, 256, 64, 16, 8, 256}, + 1 + }; + } + // Default case if none of the conditions match + return segmented_radix_sort_config_params_base(); +} + +template +constexpr auto segmented_radix_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + segmented_radix_sort_config_params> +{ + return segmented_radix_sort_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using segmented_radix_sort_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 9c44d87ca78..dd69d7a44a4 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -353,16 +353,18 @@ struct scan_by_key_config_tag {}; template -struct default_scan_config_base +constexpr scan_config_params scan_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = scan_config::value, - ::rocprim::max(1u, 16u / item_scale), - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan>; + return scan_config_params{ + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, + ::rocprim::block_load_method::block_load_transpose, + ::rocprim::block_store_method::block_store_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; }; /// \brief Provides the kernel parameters for exclusive_scan_by_key and inclusive_scan_by_key based @@ -427,17 +429,19 @@ namespace detail { template -struct default_scan_by_key_config_base +constexpr scan_by_key_config_params scan_by_key_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( - sizeof(Key) + sizeof(Value), 2 * sizeof(int)); + constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(Key) + sizeof(Value), + 2 * sizeof(int)); - using type = scan_by_key_config< - limit_block_size<256U, sizeof(Key) + sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, - ::rocprim::max(1u, 16u / item_scale), + return scan_by_key_config_params{ + {limit_block_size<256U, sizeof(Key) + sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan>; + ::rocprim::block_scan_algorithm::using_warp_scan + }; }; struct transform_config_tag @@ -479,14 +483,14 @@ struct warp_sort_config_params struct segmented_radix_sort_config_params { - /// \brief Kernel start parameters. - kernel_config_params kernel_config{}; /// \brief Number of bits in iterations. unsigned int radix_bits = 0; - /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. - bool enable_unpartitioned_warp_sort = true; + /// \brief Kernel start parameters. + kernel_config_params kernel_config{}; /// \brief Warp sort config params warp_sort_config_params warp_sort_config{}; + /// \brief If set to \p true, warp sort can be used to sort the small segments, even if no partitioning happens. + bool enable_unpartitioned_warp_sort = true; }; } // namespace detail @@ -609,17 +613,17 @@ struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_ constexpr segmented_radix_sort_config() : detail::segmented_radix_sort_config_params{ - SortConfig(), RadixBits, - EnableUnpartitionedWarpSort, + SortConfig(), {warp_sort_config::partitioning_allowed, - warp_sort_config::logical_warp_size_small, - warp_sort_config::items_per_thread_small, - warp_sort_config::block_size_small, - warp_sort_config::partitioning_threshold, - warp_sort_config::logical_warp_size_medium, - warp_sort_config::items_per_thread_medium, - warp_sort_config::block_size_medium} + warp_sort_config::logical_warp_size_small, + warp_sort_config::items_per_thread_small, + warp_sort_config::block_size_small, + warp_sort_config::partitioning_threshold, + warp_sort_config::logical_warp_size_medium, + warp_sort_config::items_per_thread_medium, + warp_sort_config::block_size_medium}, + EnableUnpartitionedWarpSort } {} #endif @@ -628,18 +632,17 @@ struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_ namespace detail { /// \brief Default segmented_radix_sort kernel configurations, such that the maximum shared memory is not exceeded. -/// -/// \tparam RadixBits Bits used during the sorting. -/// \tparam ItemsPerThread Items per thread when type Key has size 1. -template -struct default_segmented_radix_sort_config_base +template +constexpr segmented_radix_sort_config_params segmented_radix_sort_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( - sizeof(unsigned int) + sizeof(unsigned int), sizeof(int)); - using type = segmented_radix_sort_config, - WarpSortConfig<32, 4, 256, 3000, 32, 4, 256>, - true>; + constexpr unsigned int radix_bits = 6; + + return segmented_radix_sort_config_params{ + radix_bits, + kernel_config_params{128, 17u}, + warp_sort_config_params{true, 32, 4, 256, 3000, 32, 4, 256}, + true + }; }; } // namespace detail @@ -1474,15 +1477,16 @@ struct search_n_config : public detail::search_n_config_params namespace detail { template -struct default_search_n_config_base +constexpr search_n_config_params search_n_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(InputType), sizeof(int)); - using type - = search_n_config::value, - ::rocprim::max(1u, 10u / item_scale), - 8>; + return search_n_config_params{ + {limit_block_size<256u, sizeof(InputType), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale)}, + 8 + }; }; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp index f41c546c9ad..549d19eaeca 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp @@ -65,7 +65,8 @@ hipError_t search_n_impl(void* temporary_storage, { using input_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config = wrapped_search_n_config; + + using Selector = search_n_config_selector; // The `size` must greater than or equal to `count` if(count > size) @@ -75,8 +76,12 @@ hipError_t search_n_impl(void* temporary_storage, target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); - const auto params = dispatch_target_arch(target_arch); + const auto params = get_config(Config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -164,12 +169,12 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_n_normal_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + search_n_normal_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_normal_kernel", size, start); ROCPRIM_RETURN_ON_ERROR( transform(tmp_output, output, 1, identity(), stream, debug_synchronous)); @@ -244,12 +249,12 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_n_find_heads_kernel, - num_blocks_for_find_heads, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + search_n_find_heads_kernel, + num_blocks_for_find_heads, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_find_heads_kernel", possible_head_exist_size, start); @@ -294,12 +299,12 @@ hipError_t search_n_impl(void* temporary_storage, filtered_heads[atomic_add(tmp_output, 1)] = this_head; } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_n_heads_filter_kernel, - num_blocks_for_heads_filter, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + search_n_heads_filter_kernel, + num_blocks_for_heads_filter, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_heads_filter_kernel", num_groups, start); size_t h_filtered_heads_size = 0; @@ -387,12 +392,12 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_n_discard_heads_kernel, - num_blocks_for_discard_heads, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + search_n_discard_heads_kernel, + num_blocks_for_discard_heads, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_n_discard_heads_kernel ", h_filtered_heads_size, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp index 9c05485e49b..82e06315632 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp @@ -86,12 +86,18 @@ inline auto scan_impl(void* temporary_storage, scan_state_type scan_state; block_id_type block_id; - using config = wrapped_scan_config; + using Selector = scan_config_selector; detail::target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const auto params = dispatch_target_arch(target_arch); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, get_target); + const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -213,12 +219,13 @@ inline auto scan_impl(void* temporary_storage, (number_of_launch > 1), block_id); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - lookback_scan_kernel, - dim3(grid_size), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(get_target, + lookback_scan_kernel, + dim3(grid_size), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("lookback_scan_kernel", current_size, start); @@ -285,12 +292,12 @@ inline auto scan_impl(void* temporary_storage, // Save values into output array block_store_type().store(output, values, size, storage.store); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - single_scan_kernel, - dim3(1), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + single_scan_kernel, + dim3(1), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("single_scan_kernel", size, start); } return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp index 7b2862060ee..56ef2e9baed 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp @@ -171,12 +171,15 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, ROCPRIM_RETURN_ON_ERROR(std::visit( [&](auto use_sleepy_scan, auto use_atomic_block_id) { - using config = wrapped_scan_by_key_config; + using Selector = scan_by_key_config_selector; detail::target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const scan_by_key_config_params params - = dispatch_target_arch(target_arch); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + const auto params = get_config(Config{}, get_target); using wrapped_type = ::rocprim::tuple; using scan_state_type = detail::lookback_scan_state; @@ -336,12 +339,13 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, last_keys_of_each_block, ordered_bid); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - device_scan_by_key_kernel, - dim3(scan_blocks), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(get_target, + device_scan_by_key_kernel, + dim3(scan_blocks), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("device_scan_by_key_kernel", current_size, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key_config.hpp index 7018b874d8f..5f1800ef8ac 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key_config.hpp @@ -32,42 +32,20 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_scan_by_key_config +template +struct scan_by_key_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template scan_by_key_config"); + using targets = scan_by_key_targets; + using param_type = scan_by_key_config_params; - template - struct architecture_config - { - static constexpr scan_by_key_config_params params = ScanByKeyConfig{}; - }; -}; + param_type params; -template -struct wrapped_scan_by_key_config -{ - template - struct architecture_config - { - static constexpr scan_by_key_config_params params - = default_scan_by_key_config(Arch), Key, Value>{}; - }; + template + constexpr scan_by_key_config_selector(Target) + : params(scan_by_key_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr scan_by_key_config_params - wrapped_scan_by_key_config::architecture_config::params; - -template -template -constexpr scan_by_key_config_params - wrapped_scan_by_key_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan_config.hpp index f2a4254da20..10289469d7d 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan_config.hpp @@ -32,40 +32,18 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_scan_config +template +struct scan_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template scan_config"); - template - struct architecture_config - { - static constexpr scan_config_params params = ScanConfig{}; - }; -}; + using targets = scan_targets; + using param_type = scan_config_params; -template -struct wrapped_scan_config -{ - template - struct architecture_config - { - static constexpr scan_config_params params - = default_scan_config(Arch), Value>{}; - }; -}; + param_type params; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr scan_config_params - wrapped_scan_config::architecture_config::params; - -template -template -constexpr scan_config_params - wrapped_scan_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr scan_config_selector(Target) : params(scan_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_search_n_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_search_n_config.hpp index 1d645eed3f6..1a4145ea184 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_search_n_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_search_n_config.hpp @@ -30,40 +30,18 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_search_n_config +template +struct search_n_config_selector { - template - struct architecture_config - { - static constexpr search_n_config_params params = Config{}; - }; -}; + using targets = search_n_targets; + using param_type = search_n_config_params; -// specialized for rocprim::default_config, which instantiates the default_search_n_config -template -struct wrapped_search_n_config -{ - template - struct architecture_config - { - static constexpr search_n_config_params params - = default_search_n_config(Arch), Value>{}; - }; -}; + param_type params; -#ifndef DOXYGEN_DOCUMENTATION_BUILD -template -template -constexpr search_n_config_params - wrapped_search_n_config::architecture_config::params; - -template -template -constexpr search_n_config_params - wrapped_search_n_config::architecture_config::params; -#endif // DOXYGEN_DOCUMENTATION_BUILD + template + constexpr search_n_config_selector(Target) : params(search_n_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index d3d4290d4b1..98f7a252b07 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -150,17 +150,17 @@ inline hipError_t segmented_radix_sort_impl( typename std::iterator_traits::value_type>::value, "ValuesInputIterator and ValuesOutputIterator must have the same value_type"); - using config = wrapped_segmented_radix_sort_config; + using Selector = segmented_radix_sort_config_selector; detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); - const detail::segmented_radix_sort_config_params params - = detail::dispatch_target_arch(target_arch); + const auto params = get_config(Config{}, get_target); static constexpr bool with_values = !std::is_same::value; const bool partitioning_allowed = params.warp_sort_config.partitioning_allowed; @@ -353,12 +353,12 @@ inline hipError_t segmented_radix_sort_impl( }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(target_arch, - segmented_sort_large_kernel, - dim3(large_segment_count), - dim3(params.kernel_config.block_size), - 0, - stream)); + execute_launch_plan(get_target, + segmented_sort_large_kernel, + dim3(large_segment_count), + dim3(params.kernel_config.block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort:large_segments", large_segment_count, start); @@ -391,10 +391,10 @@ inline hipError_t segmented_radix_sort_impl( }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + get_target, segmented_sort_medium_kernel, dim3(medium_segment_grid_size), dim3(params.warp_sort_config.block_size_medium), @@ -432,10 +432,10 @@ inline hipError_t segmented_radix_sort_impl( }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + get_target, segmented_sort_small_kernel, dim3(small_segment_grid_size), dim3(params.warp_sort_config.block_size_small), @@ -469,12 +469,13 @@ inline hipError_t segmented_radix_sort_impl( end_bit); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - segmented_sort_kernel, - dim3(segments), - dim3(params.kernel_config.block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(get_target, + segmented_sort_kernel, + dim3(segments), + dim3(params.kernel_config.block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_sort", segments, start); } return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp index f4a1c8de9f3..e267ae58dca 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp @@ -57,46 +57,33 @@ using select_warp_sort_config_t namespace detail { -template -struct wrapped_segmented_radix_sort_config +template +struct segmented_radix_sort_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template segmented_radix_sort_config"); - - template - struct architecture_config - { - static constexpr detail::segmented_radix_sort_config_params params - = SegmentedRadixSortConfig{}; - }; + using targets = segmented_radix_sort_targets; + using param_type = segmented_radix_sort_config_params; + + param_type params; + + template + constexpr segmented_radix_sort_config_selector(Target) + : params(segmented_radix_sort_config_picker()) + {} }; -template -struct wrapped_segmented_radix_sort_config +template +struct segmented_radix_sort_warp_sort_small_config_static_selector { - template - struct architecture_config - { - static constexpr segmented_radix_sort_config_params params - = detail::default_segmented_radix_sort_config(Arch), - key_type, - value_type>{}; - }; + static constexpr auto block_size + = target_config2::params.warp_sort_config.block_size_small; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr segmented_radix_sort_config_params - wrapped_segmented_radix_sort_config:: - architecture_config::params; -template -template -constexpr segmented_radix_sort_config_params - wrapped_segmented_radix_sort_config:: - architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS +template +struct segmented_radix_sort_warp_sort_medium_config_static_selector +{ + static constexpr auto block_size + = target_config2::params.warp_sort_config.block_size_medium; +}; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp index cac12f2ef82..df3af814961 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp @@ -94,15 +94,17 @@ inline hipError_t segmented_scan_impl(void* temporary_storage, using input_type = typename std::iterator_traits::value_type; using result_type = typename std::conditional::type; - using config = wrapped_scan_config; + using Selector = detail::scan_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const scan_config_params params = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; @@ -133,12 +135,12 @@ inline hipError_t segmented_scan_impl(void* temporary_storage, scan_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - segmented_scan_kernel, - dim3(segments), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + segmented_scan_kernel, + dim3(segments), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_scan", segments, start); return hipSuccess; } diff --git a/projects/rocprim/test/rocprim/test_device_scan.cpp b/projects/rocprim/test/rocprim/test_device_scan.cpp index 8327b89fe63..e8d5badfc2e 100644 --- a/projects/rocprim/test/rocprim/test_device_scan.cpp +++ b/projects/rocprim/test/rocprim/test_device_scan.cpp @@ -305,15 +305,20 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScan) const bool deterministic = TestFixture::deterministic; const bool use_initial_value = TestFixture::use_initial_value; - using Config = typename TestFixture::config_helper; - using config = rocprim::detail::wrapped_scan_config; + using Config = typename TestFixture::config_helper; + using Selector = rocprim::detail::scan_config_selector; hipStream_t stream = hipStreamDefault; rocprim::detail::target_arch target_arch; - HIP_CHECK(host_target_arch(stream, target_arch)); - const rocprim::detail::scan_config_params params - = rocprim::detail::dispatch_target_arch(target_arch); + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(Config{}, get_target); // For non-associative operations in inclusive scan // intermediate results use the type of input iterator, then @@ -497,12 +502,13 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScan) false, ordered_bid); }; - return rocprim::detail::execute_launch_plan(target_arch, - lookback_scan_kernel, - dim3(grid_size), - dim3(block_size), - 0, - stream); + return rocprim::detail::execute_launch_plan( + get_target, + lookback_scan_kernel, + dim3(grid_size), + dim3(block_size), + 0, + stream); }); ASSERT_EQ(hipSuccess, launch_err); @@ -553,15 +559,20 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScanGetCompleteValue) const bool deterministic = TestFixture::deterministic; const bool use_initial_value = TestFixture::use_initial_value; - using Config = typename TestFixture::config_helper; - using config = rocprim::detail::wrapped_scan_config; + using Config = typename TestFixture::config_helper; + using Selector = rocprim::detail::scan_config_selector; hipStream_t stream = hipStreamDefault; rocprim::detail::target_arch target_arch; - HIP_CHECK(host_target_arch(stream, target_arch)); - const rocprim::detail::scan_config_params params - = rocprim::detail::dispatch_target_arch(target_arch); + HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); + + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(Config{}, get_target); // For non-associative operations in inclusive scan // intermediate results use the type of input iterator, then @@ -735,12 +746,12 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScanGetCompleteValue) false, ordered_bid); }; - return rocprim::detail::execute_launch_plan(target_arch, - lookback_scan_kernel, - dim3(grid_size), - dim3(block_size), - 0, - stream); + return rocprim::detail::execute_launch_plan(get_target, + lookback_scan_kernel, + dim3(grid_size), + dim3(block_size), + 0, + stream); }); ASSERT_EQ(hipSuccess, launch_err); diff --git a/projects/rocprim/test/rocprim/test_device_search_n.cpp b/projects/rocprim/test/rocprim/test_device_search_n.cpp index c7622e319fc..07387e453e8 100644 --- a/projects/rocprim/test/rocprim/test_device_search_n.cpp +++ b/projects/rocprim/test/rocprim/test_device_search_n.cpp @@ -904,13 +904,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_1block) for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1025,13 +1030,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_2block) for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1146,13 +1156,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_3block) for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1267,13 +1282,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult1) for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1389,13 +1409,18 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult2) = seed_index < random_seeds_count ? rand() : seeds[seed_index - random_seeds_count]; for(const auto size : test_utils::get_sizes(seed_value)) { - using wrapped_config = rocprim::detail::wrapped_search_n_config; - size_t temp_storage_size; - hipStream_t stream = 0; // default + using Selector = rocprim::detail::search_n_config_selector; + size_t temp_storage_size; + hipStream_t stream = 0; // default + rocprim::detail::target_arch target_arch; HIP_CHECK(rocprim::detail::host_target_arch(stream, target_arch)); - const auto params - = rocprim::detail::dispatch_target_arch(target_arch); + rocprim::detail::gpu target_gpu; + HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); + + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params = rocprim::detail::get_config(config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; diff --git a/projects/rocprim/test/rocprim/test_linking_new_scan.hpp b/projects/rocprim/test/rocprim/test_linking_new_scan.hpp index ef51bcb30ea..ba048b02204 100644 --- a/projects/rocprim/test/rocprim/test_linking_new_scan.hpp +++ b/projects/rocprim/test/rocprim/test_linking_new_scan.hpp @@ -98,15 +98,17 @@ inline auto scan_impl(void* temporary_storage, { (void)debug_synchronous; - using config = wrapped_scan_config; + using Selector = scan_config_selector; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const scan_config_params params = dispatch_target_arch(target_arch); + rocprim::detail::target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + rocprim::detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -133,12 +135,12 @@ inline auto scan_impl(void* temporary_storage, output, scan_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - single_scan_kernel, - dim3(1), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + single_scan_kernel, + dim3(1), + dim3(block_size), + 0, + stream)); return hipGetLastError(); } From 99c7297d2b032293c945fe2fe7b4e9954a77bb67 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 6 Nov 2025 10:48:33 +0000 Subject: [PATCH 03/26] Resolve "Update configs for new config system part 4" --- .../include/rocprim/device/config_types.hpp | 7 - .../config/device_radix_sort_block_sort.hpp | 7363 +++++-------- .../device/detail/config/device_reduce.hpp | 1349 ++- .../detail/config/device_reduce_by_key.hpp | 9301 +++++++---------- .../config/device_run_length_encode.hpp | 1849 ++-- .../device_run_length_encode_non_trivial.hpp | 1643 ++- .../detail/config/device_segmented_reduce.hpp | 1225 +-- .../device/detail/device_config_helper.hpp | 119 +- .../rocprim/device/device_radix_sort.hpp | 8 +- .../device/device_radix_sort_config.hpp | 43 +- .../include/rocprim/device/device_reduce.hpp | 35 +- .../rocprim/device/device_reduce_by_key.hpp | 53 +- .../device/device_reduce_by_key_config.hpp | 94 +- .../rocprim/device/device_reduce_config.hpp | 39 +- .../device/device_run_length_encode.hpp | 59 +- .../device_run_length_encode_config.hpp | 224 +- .../device/device_segmented_reduce.hpp | 30 +- .../device/device_segmented_reduce_config.hpp | 40 +- .../device_radix_block_sort.hpp | 34 +- .../device_radix_merge_sort.hpp | 8 +- 20 files changed, 9412 insertions(+), 14111 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 756be84f7b0..ab20e0922c6 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -477,13 +477,6 @@ struct merge_mergepath_config_selector = Config::template architecture_config::params.merge_mergepath_config.block_size; }; -template -struct radix_sort_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.block_size; -}; - template struct target_config { diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp index 9bd1f70448d..e29d1eeba08 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_block_sort.hpp @@ -40,4747 +40,2628 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_radix_sort_block_sort_config - : radix_sort_block_sort_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 25> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 25> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 28> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 29> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 29> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 25> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 29> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 29> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 13> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 18> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<128, 16> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 25> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 28> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 28> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 28> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 29> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 31> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 30> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 30> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 25> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 23> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 31> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 10> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 30> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 30> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 31> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 28> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<128, 25> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 30> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 27> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 26> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 24> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<128, 29> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 29> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<128, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 30> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<128, 25> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 29> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<128, 32> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 24> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 29> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 29> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 22> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 25> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 22> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<128, 25> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 25> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 30> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 27> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 26> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 25> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<128, 27> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<64, 29> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<64, 30> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 22> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 28> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 25> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 29> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 25> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<128, 29> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 29> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<64, 32> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<128, 30> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 27> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<64, 31> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 7> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 13> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 18> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 14> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 23> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 12> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 19> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<256, 22> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 6> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 7> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 12> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 12> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 19> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 23> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 14> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 19> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<256, 23> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 10> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 17> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<256, 22> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 7> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 11> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 15> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 30> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 29> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 31> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 28> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<1024, 18> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<1024, 22> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 7> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 15> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 14> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<512, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 29> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 31> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 18> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<1024, 16> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<1024, 22> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<1024, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<1024, 15> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<1024, 23> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 15> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 14> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<512, 18> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<256, 32> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 23> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<256, 27> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 5> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 23> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 7> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 11> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 12> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 15> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 30> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 15> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 29> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 31> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 28> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<1024, 18> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 27> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<1024, 22> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 7> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 15> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<512, 14> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 31> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<512, 30> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 29> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<512, 31> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 12> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<512, 30> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<1024, 18> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<1024, 16> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<1024, 22> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 30> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<1024, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<1024, 15> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<1024, 23> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 21> -{}; - -// Based on key_type = float, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 29> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<256, 32> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<512, 14> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<512, 18> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 14> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : kernel_config<512, 21> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 5> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 5> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 5> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 5> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : kernel_config<256, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 10> -{}; - -// Based on key_type = int, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 21> -{}; - -// Based on key_type = int, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<256, 16> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 19> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<256, 16> -{}; - -// Based on key_type = short, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<256, 29> -{}; - -// Based on key_type = short, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<512, 18> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<512, 15> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : kernel_config<512, 22> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<512, 7> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : kernel_config<1024, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : kernel_config<256, 21> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<256, 23> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : kernel_config<512, 21> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<64, 16> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<128, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : kernel_config<64, 17> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : kernel_config<64, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : kernel_config<64, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : kernel_config<64, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : kernel_config<64, 16> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<64, 16> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : kernel_config<256, 15> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_radix_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : kernel_config<256, 15> -{}; +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 28}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 13}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 18}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 28}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 28}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 28}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 31}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 26}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 24}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{128, 27}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{128, 32}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 24}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 22}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 22}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{128, 25}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 26}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 27}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 22}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 28}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 25}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 25}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{128, 29}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 29}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 32}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{128, 30}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 27}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{64, 31}; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{128, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{64, 17}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{64, 16}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 7}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 13}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 18}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 14}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 23}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 19}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 22}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 6}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 7}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 19}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 23}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 14}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 19}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 13}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 17}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{256, 22}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 7}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{1024, 15}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 29}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 31}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 28}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{1024, 18}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 27}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{1024, 22}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 7}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 11}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 14}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 31}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 29}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 31}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 12}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{1024, 18}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{1024, 16}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{1024, 22}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 13}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 30}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{1024, 16}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{1024, 15}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{1024, 23}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 14}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{512, 18}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 32}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 15}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{256, 27}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 5}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 30}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{256, 29}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 32}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{512, 8}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{512, 14}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{512, 18}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 14}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return kernel_config_params{512, 21}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 5}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 5}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 5}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 5}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 10}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 19}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return kernel_config_params{256, 8}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return kernel_config_params{256, 16}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 29}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return kernel_config_params{512, 18}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{512, 15}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return kernel_config_params{512, 22}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return kernel_config_params{512, 7}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return kernel_config_params{1024, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return kernel_config_params{256, 21}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return kernel_config_params{256, 23}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return kernel_config_params{512, 21}; + } + // Default case if none of the conditions match + return radix_sort_block_sort_config_params_base(); +} + +template +constexpr auto radix_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + kernel_config_params> +{ + return radix_sort_block_sort_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using radix_sort_block_sort_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce.hpp index 6a8d99cdd52..0ebebb30cd5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce.hpp @@ -40,702 +40,659 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_reduce_config : default_reduce_config_base::type -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 1, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<128, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<64, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<64, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<128, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_reduce_config(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 2, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_reduce_config(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 4, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 1}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {128, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {128, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {128, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {128, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {128, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {64, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 4}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {64, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {128, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {128, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 2}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + return reduce_config_picker< + comp_target, + key_type>(); +} + +// All the existing configs should be auto generated +using reduce_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce_by_key.hpp index 7c9a06431bc..6b346e8017c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_reduce_by_key.hpp @@ -38,5561 +38,3752 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_reduce_by_key_config : default_reduce_by_key_config_base::type -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 11, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 9, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 13, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 13, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<128, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<384, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<128, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<128, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<128, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<384, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<384, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<384, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<512, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 4, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 12, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<256, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = double -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = float -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::half -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<128, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : reduce_by_key_config<128, - 6, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : reduce_by_key_config<256, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_reduce_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<((sizeof(key_type) <= 1) - && !bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : reduce_by_key_config<192, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {128, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 11}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 4}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 9}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 13}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 13}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {128, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {384, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {384, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {384, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {128, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {128, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {128, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {128, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {128, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {128, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 6}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 12}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {128, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = rocprim::int128_t, value_type = double + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = float + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 16) && (sizeof(key_type) > 8) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = double + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = float + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 8) && (sizeof(key_type) > 4) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = double + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = float + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = short + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr(((sizeof(key_type) <= 4) && (sizeof(key_type) > 2) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = double + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = float + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = short + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr(((sizeof(key_type) <= 2) && (sizeof(key_type) > 1) + && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = double + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = float + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::half + if constexpr(((sizeof(key_type) <= 1) && bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr(((sizeof(key_type) <= 1) && !bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto reduce_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + return reduce_by_key_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using reduce_by_key_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode.hpp index 0f8c32c4cee..19f14d24b31 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode.hpp @@ -38,1075 +38,786 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_trivial_runs_config : default_reduce_by_key_config_base::type -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<256, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 8, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<192, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<384, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<384, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<384, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<192, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<256, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<192, - 7, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<256, - 10, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<192, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_by_key_config<512, - 14, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_by_key_config<512, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_by_key_config<512, - 5, - block_load_method::block_load_transpose, - block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 8}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {192, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {384, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {384, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {384, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {512, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 5}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {192, 7}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {256, 10}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {192, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_by_key_config_params{ + {512, 14}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_by_key_config_params{ + {512, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_direct, + block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return reduce_by_key_config_params_base(); +} + +template +constexpr auto run_length_encode_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_by_key_config_params> +{ + return run_length_encode_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using run_length_encode_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp index e4042269f6c..731b6571442 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_run_length_encode_non_trivial.hpp @@ -40,933 +40,722 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_non_trivial_runs_config : default_non_trivial_runs_config_base::type -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 8, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<128, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1200), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<512, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<512, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<128, - 32, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 8, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<128, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<64, - 64, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 8, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<64, - 64, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<64, - 16, - ::rocprim::block_load_method::block_load_warp_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = double -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<512, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = float -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::half -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int64_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : non_trivial_runs_config<512, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = short -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = int8_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : non_trivial_runs_config<256, - 16, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_non_trivial_runs_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : non_trivial_runs_config<256, - 8, - ::rocprim::block_load_method::block_load_vectorize, - ::rocprim::block_scan_algorithm::using_warp_scan> -{}; +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 16}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 8}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {512, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {512, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {128, 32}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 8}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {128, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {64, 16}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {64, 64}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 16}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {64, 16}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {64, 8}, + ::rocprim::block_load_method::block_load_warp_transpose, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {512, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return non_trivial_runs_config_params{ + {256, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return non_trivial_runs_config_params{ + {512, 8}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return non_trivial_runs_config_params{ + {256, 16}, + ::rocprim::block_load_method::block_load_vectorize, + ::rocprim::block_scan_algorithm::using_warp_scan + }; + } + // Default case if none of the conditions match + return non_trivial_runs_config_params_base(); +} + +template +constexpr auto run_length_encode_non_trivial_config_picker() -> std::enable_if_t< + std::is_same>::value, + non_trivial_runs_config_params> +{ + return run_length_encode_non_trivial_config_picker< + comp_target, + key_type>(); +} + +// All the existing configs should be auto generated +using run_length_encode_non_trivial_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_reduce.hpp index b04326b6ac8..9f6b159fa05 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_segmented_reduce.hpp @@ -38,650 +38,587 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_segmented_reduce_config : default_reduce_config_base::type -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1030), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1100), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx906), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx908), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx90a), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<128, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::unknown), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx1201), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = double -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = float -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::half -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 2))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = rocprim::int128_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int64_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4))>> - : reduce_config<256, 16, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = short -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; - -// Based on key_type = int8_t -template -struct default_segmented_reduce_config< - static_cast(target_arch::gfx942), - key_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(key_type) <= 1))>> - : reduce_config<256, 8, ::rocprim::block_reduce_algorithm::using_warp_reduce> -{}; +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {128, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + // Based on key_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4))) + { + return reduce_config_params{ + {256, 16}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Based on key_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1))) + { + return reduce_config_params{ + {256, 8}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; + } + // Default case if none of the conditions match + return reduce_config_params_base(); +} + +template +constexpr auto segmented_reduce_config_picker() -> std::enable_if_t< + std::is_same>::value, + reduce_config_params> +{ + return segmented_reduce_config_picker< + comp_target, + key_type>(); +} + +// All the existing configs should be auto generated +using segmented_reduce_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index dd69d7a44a4..e37f7232dce 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -116,20 +116,21 @@ struct merge_sort_block_sort_config_base // No radix_sort_block_sort_params and radix_sort_block_sort_config exist since the only // configuration member is a kernel_config. template -struct radix_sort_block_sort_config_base +constexpr kernel_config_params radix_sort_block_sort_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); // multiply by 2 to ensure block_sort's items_per_block >= block_merge's items_per_block - static constexpr unsigned int block_size = merge_sort_block_size(item_scale) * 2; - static constexpr unsigned int items_per_thread + constexpr unsigned int block_size = merge_sort_block_size(item_scale) * 2; + constexpr unsigned int items_per_thread = rocprim::min(4u, merge_sort_items_per_thread(item_scale)); - using type = kernel_config; // The items per block should be a power of two, as this is a requirement for the // radix sort merge sort. static_assert(is_power_of_two(block_size * items_per_thread), "Sorted items per block should be a power of two."); + + return kernel_config_params{block_size, items_per_thread}; }; /// \brief Default values are provided by \p merge_sort_block_merge_config_base. @@ -275,14 +276,16 @@ namespace detail { template -struct default_reduce_config_base +constexpr reduce_config_params reduce_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = reduce_config::value, - ::rocprim::max(1u, 16u / item_scale), - ::rocprim::block_reduce_algorithm::using_warp_reduce>; + return reduce_config_params{ + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, + ::rocprim::block_reduce_algorithm::using_warp_reduce + }; }; struct scan_config_tag @@ -1130,10 +1133,10 @@ struct default_partition_config_base struct reduce_by_key_config_params { kernel_config_params kernel_config; - unsigned int tiles_per_block; block_load_method load_keys_method; block_load_method load_values_method; block_scan_algorithm scan_algorithm; + unsigned int tiles_per_block = 1; }; } // namespace detail @@ -1182,10 +1185,10 @@ struct reduce_by_key_config : public detail::reduce_by_key_config_params constexpr reduce_by_key_config() : detail::reduce_by_key_config_params{ {BlockSize, ItemsPerThread, SizeLimit}, - TilesPerBlock, LoadKeysMethod, LoadValuesMethod, - ScanAlgorithm + ScanAlgorithm, + TilesPerBlock } {}; #endif }; @@ -1194,32 +1197,36 @@ namespace detail { template -struct default_reduce_by_key_config_base -{ - using small_config = reduce_by_key_config<256, - 15, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - sizeof(Value) < 16 ? 1 : 2>; +constexpr reduce_by_key_config_params reduce_by_key_config_params_base() +{ + constexpr auto small_config = reduce_by_key_config_params{ + {256, 15}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + sizeof(Value) < 16 ? 1 : 2, + }; - static constexpr unsigned int size_memory_per_item = std::max(sizeof(Key), sizeof(Value)); - static constexpr unsigned int item_scale - = static_cast(ceiling_div(size_memory_per_item, 2 * sizeof(int))); - static constexpr unsigned int items_per_thread = std::max(1u, 15u / item_scale); + if constexpr(std::max(sizeof(Key), sizeof(Value)) <= 16) + { + return small_config; + } - using large_config - = reduce_by_key_config::value, - items_per_thread, - block_load_method::block_load_transpose, - block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan, - 2>; + constexpr unsigned int size_memory_per_item = std::max(sizeof(Key), sizeof(Value)); + constexpr unsigned int item_scale + = static_cast(ceiling_div(size_memory_per_item, 2 * sizeof(int))); + constexpr unsigned int items_per_thread = std::max(1u, 15u / item_scale); + + constexpr auto large_config = reduce_by_key_config_params{ + {limit_block_size<256U, items_per_thread * size_memory_per_item, ROCPRIM_WARP_SIZE_64>:: + value, items_per_thread}, + block_load_method::block_load_transpose, + block_load_method::block_load_transpose, + block_scan_algorithm::using_warp_scan, + 2 + }; - using type = std:: - conditional_t; + return large_config; }; } // namespace detail @@ -1315,35 +1322,41 @@ namespace detail { template -struct default_non_trivial_runs_config_base +constexpr non_trivial_runs_config_params non_trivial_runs_config_params_base() { - static constexpr unsigned int items_per_thread = 16; - using small_config = non_trivial_runs_config<256U, - items_per_thread, - block_load_method::block_load_vectorize, - block_scan_algorithm::reduce_then_scan>; + constexpr unsigned int items_per_thread = 16; + constexpr auto small_config = non_trivial_runs_config_params{ + {256U, items_per_thread}, + block_load_method::block_load_vectorize, + block_scan_algorithm::reduce_then_scan + }; + + if constexpr(sizeof(InputT) < 8) + { + return small_config; + } using OffsetCountPairT = ::rocprim::tuple; - static constexpr unsigned int size_memory_per_item + constexpr unsigned int size_memory_per_item = std::max(sizeof(InputT), sizeof(OffsetCountPairT)); // Additional shared memory is required by the lookback scan state. - static constexpr unsigned int shared_mem_offset + constexpr unsigned int shared_mem_offset = sizeof(typename offset_lookback_scan_prefix_op< OffsetCountPairT, lookback_scan_state>::storage_type); - using big_config - = non_trivial_runs_config::value, - items_per_thread, - block_load_method::block_load_warp_transpose, - block_scan_algorithm::reduce_then_scan>; + constexpr auto big_config = non_trivial_runs_config_params{ + {detail::limit_block_size<64U, + items_per_thread * size_memory_per_item, + ROCPRIM_WARP_SIZE_64, shared_mem_offset>::value, + items_per_thread}, + block_load_method::block_load_warp_transpose, + block_scan_algorithm::reduce_then_scan + }; - using type = std::conditional_t; + return big_config; }; struct find_first_of_config_params diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp index a4d993087c0..16e1ca69d69 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp @@ -610,11 +610,11 @@ hipError_t constexpr bool use_default_small_block_sort = is_default_config || std::is_same::value; - using default_radix_sort_block_sort_config = - typename radix_sort_block_sort_config_base::type; + constexpr auto default_radix_sort_block_sort_config + = radix_sort_block_sort_config_params_base(); using default_block_sort_config - = kernel_config; + = kernel_config; using block_sort_config = typename std::conditional::type; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp index ed69e8479b3..d30d20c5afc 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort_config.hpp @@ -94,41 +94,26 @@ struct radix_sort_onesweep_sort_config_static_selector .radix_sort_onesweep_config_params::sort.block_size; }; -// Sub-algorithm block_sort: -template -struct wrapped_radix_sort_block_sort_config +template +struct radix_sort_block_sort_config_selector { - template - struct architecture_config - { - static constexpr kernel_config_params params = RadixSortBlockSortConfig(); - }; + using targets = radix_sort_block_sort_targets; + using param_type = kernel_config_params; + + param_type params; + + template + constexpr radix_sort_block_sort_config_selector(Target) + : params(radix_sort_block_sort_config_picker()) + {} }; -template -struct wrapped_radix_sort_block_sort_config +template +struct radix_sort_block_sort_config_static_selector { - template - struct architecture_config - { - static constexpr kernel_config_params params - = default_radix_sort_block_sort_config(Arch), Key, Value>(); - }; + static constexpr auto block_size = target_config2::params.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr kernel_config_params - wrapped_radix_sort_block_sort_config::architecture_config< - Arch>::params; - -template -template -constexpr kernel_config_params wrapped_radix_sort_block_sort_config:: - architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp index 02c1a7043c5..1368cd0216a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp @@ -80,12 +80,12 @@ namespace detail result_type>(input, size, output, initial_value, reduce_op); \ }; \ \ - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, \ - block_reduce_kernel, \ - dim3(1), \ - dim3(block_size), \ - 0, \ - stream)); \ + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, \ + block_reduce_kernel, \ + dim3(1), \ + dim3(block_size), \ + 0, \ + stream)); \ ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_reduce_kernel", size, start); \ } \ while(0) @@ -109,12 +109,17 @@ inline hipError_t reduce_impl(void* temporary_storage, using input_type = typename std::iterator_traits::value_type; using result_type = ::rocprim::accumulator_t; - using config = wrapped_reduce_config; + using Selector = reduce_config_selector; - detail::target_arch target_arch; + target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const reduce_config_params params = dispatch_target_arch(target_arch); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -196,12 +201,12 @@ inline hipError_t reduce_impl(void* temporary_storage, initial_value, reduce_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - block_reduce_kernel, - dim3(current_blocks), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + block_reduce_kernel, + dim3(current_blocks), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("block_reduce_kernel", current_size, start); } diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp index 666470a92fc..bf1d232cdb2 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp @@ -97,7 +97,8 @@ ROCPRIM_KERNEL ROCPRIM_LAUNCH_BOUNDS(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) void } template(target_arch); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, get_target); using scan_state_type = reduce_by_key::lookback_scan_state_t; @@ -268,12 +274,13 @@ hipError_t reduce_by_key_impl_wrapped_config(void* temporary i > 0 ? d_previous_accumulated : nullptr, ordered_bid); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - kernel, - dim3(number_of_blocks_launch), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(get_target, + kernel, + dim3(number_of_blocks_launch), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("reduce_by_key_kernel", current_size, start); @@ -310,21 +317,21 @@ hipError_t reduce_by_key_impl(void* temporary_storage, { using key_type = ::rocprim::detail::value_type_t; using accumulator_type = reduce_by_key::accumulator_type_t; + using Selector = reduce_by_key_config_selector; - using config = wrapped_reduce_by_key_config; - - return detail::reduce_by_key_impl_wrapped_config(temporary_storage, - storage_size, - keys_input, - values_input, - size, - unique_output, - aggregates_output, - unique_count_output, - reduce_op, - key_compare_op, - stream, - debug_synchronous); + return detail::reduce_by_key_impl_wrapped_config( + temporary_storage, + storage_size, + keys_input, + values_input, + size, + unique_output, + aggregates_output, + unique_count_output, + reduce_op, + key_compare_op, + stream, + debug_synchronous); } } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key_config.hpp index 9b028a64434..00b13a25601 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key_config.hpp @@ -33,84 +33,34 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// Generic for user-provided config: instantiate user-provided config. -template -struct wrapped_reduce_by_key_config +template +struct reduce_by_key_config_selector { - template - struct architecture_config - { - static constexpr reduce_by_key_config_params params = ReduceByKeyConfig{}; - }; -}; + using targets = reduce_by_key_targets; + using param_type = reduce_by_key_config_params; -// Generic for default config: instantiate base config. -template -struct wrapped_reduce_by_key_impl -{ - template - struct architecture_config - { - static constexpr reduce_by_key_config_params params = - typename default_reduce_by_key_config_base::type{}; - }; -}; + param_type params; -// Specialization for default config if types are not custom: instantiate the tuned config. -template -struct wrapped_reduce_by_key_impl< - KeyType, - AccumulatorType, - BinaryFunction, - std::enable_if_t::value && is_arithmetic::value - && is_binary_functional::value>> -{ - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr reduce_by_key_config_params params - = default_reduce_by_key_config(Arch), - KeyType, - AccumulatorType>{}; - }; + // Specialization for default config if types are not custom: instantiate the tuned config. + if constexpr(rocprim::is_arithmetic::value && rocprim::is_arithmetic::value + && rocprim::detail::is_binary_functional::value) + { + return reduce_by_key_config_picker(); + } + else + { + return reduce_by_key_config_params_base(); + } + } + + template + constexpr reduce_by_key_config_selector(Target) : params(picker_helper()) + {} }; -// Specialization for default config. -template -struct wrapped_reduce_by_key_config - : wrapped_reduce_by_key_impl -{}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr reduce_by_key_config_params - wrapped_reduce_by_key_config:: - architecture_config::params; - -template -template -constexpr reduce_by_key_config_params - wrapped_reduce_by_key_impl:: - architecture_config::params; - -template -template -constexpr reduce_by_key_config_params wrapped_reduce_by_key_impl< - KeyType, - AccumulatorType, - BinaryFunction, - std::enable_if_t::value && is_arithmetic::value - && is_binary_functional::value>>::architecture_config:: - params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_config.hpp index 45f0464536f..c41acc8690a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_config.hpp @@ -33,42 +33,19 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_reduce_config +template +struct reduce_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template reduce_config"); + using targets = reduce_targets; + using param_type = reduce_config_params; - template - struct architecture_config - { - static constexpr reduce_config_params params = ReduceConfig(); - }; -}; + param_type params; -template -struct wrapped_reduce_config -{ - template - struct architecture_config - { - static constexpr reduce_config_params params - = default_reduce_config(Arch), Value>(); - }; + template + constexpr reduce_config_selector(Target) : params(reduce_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr reduce_config_params - wrapped_reduce_config::architecture_config::params; - -template -template -constexpr reduce_config_params - wrapped_reduce_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp index 6d6a4720878..949074a9472 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp @@ -76,23 +76,23 @@ hipError_t run_length_encode_impl(void* temporary_storage, { using key_type = ::rocprim::detail::value_type_t; using accumulator_type = reduce_by_key::accumulator_type_t; - - using config = wrapped_trivial_runs_config; + using Selector = run_length_encode_config_selector; return detail::reduce_by_key_impl_wrapped_config< detail::lookback_scan_determinism::nondeterministic, - config>(temporary_storage, - storage_size, - keys_input, - values_input, - size, - unique_output, - aggregates_output, - unique_count_output, - reduce_op, - key_compare_op, - stream, - debug_synchronous); + typename select_reduce_by_key_config::type, + Selector>(temporary_storage, + storage_size, + keys_input, + values_input, + size, + unique_output, + aggregates_output, + unique_count_output, + reduce_op, + key_compare_op, + stream, + debug_synchronous); } template; // accumulator_type + // RLE config needs to be converted to non_trivial_runs_config. + using non_trivial_config = typename convert_to_non_trivial_config::type; + using Selector = run_length_encode_non_trivial_config_selector; bool use_atomic_block_id; ROCPRIM_RETURN_ON_ERROR(check_if_using_atomic_block_id(stream, use_atomic_block_id)); @@ -130,16 +133,19 @@ hipError_t run_length_encode_non_trivial_runs_impl(void* tempo ROCPRIM_RETURN_ON_ERROR(std::visit( [&](auto use_sleepy_scan, auto use_atomic_block_id) { - using config = rocprim::detail::wrapped_non_trivial_runs_config; + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const auto params = get_config(non_trivial_config{}, get_target); using scan_state_type = ::rocprim::detail::lookback_scan_state; - detail::target_arch target_arch; - ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - - const non_trivial_runs_config_params params - = dispatch_target_arch(target_arch); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_block = block_size * params.kernel_config.items_per_thread; const std::size_t grid_size = detail::ceiling_div(size, items_per_block); @@ -221,12 +227,13 @@ hipError_t run_length_encode_non_trivial_runs_impl(void* tempo ordered_bid); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - non_trivial_kernel, - dim3(grid_size), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(get_target, + non_trivial_kernel, + dim3(grid_size), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("run_length_encode::non_trivial_kernel", size, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode_config.hpp index 708e89be3ee..5da14c7f4c5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode_config.hpp @@ -55,180 +55,94 @@ struct run_length_encode_config namespace detail { -template -struct wrapped_trivial_runs_config - : wrapped_reduce_by_key_config -{}; - -template -struct wrapped_trivial_runs_config< - rocprim::run_length_encode_config, - KeyType, - AccumulatorType, - BinaryFunction> - : wrapped_reduce_by_key_config -{}; - -template -struct wrapped_trivial_runs_impl - : wrapped_reduce_by_key_impl -{}; - -template -struct wrapped_trivial_runs_impl< - KeyType, - AccumulatorType, - BinaryFunction, - std::enable_if_t::value && is_arithmetic::value - && is_binary_functional::value>> +template +struct select_reduce_by_key_config { - template - struct architecture_config - { - static constexpr reduce_by_key_config_params params - = default_trivial_runs_config(Arch), - KeyType, - AccumulatorType>{}; - }; + using type = typename Config::reduce_by_key; }; -template -struct wrapped_trivial_runs_config - : wrapped_trivial_runs_impl -{}; +template<> +struct select_reduce_by_key_config +{ + using type = rocprim::default_config; +}; -// Wrap around run_length_encode_config and the newly added non_trivial_runs_config for the -// run_length_encode_non_trivial_runs algorithm. Three cases are considered for selecting -// the appropriate config: -// -// - When a run_length_encode_config struct is passed as argument, an specialization of -// this struct takes care of mapping the parameters of that config to the newly added -// non_trivial_runs_config. -// -// - When a default config is passed, another specialization takes care of using the -// default set up of non_trivial_runs_config. -// -// - When a non_trivial_runs_config is passed, the params are set from this config. -// -template -struct wrapped_non_trivial_runs_config +template +struct run_length_encode_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template non_trivial_runs_config"); + using targets = run_length_encode_targets; + using param_type = reduce_by_key_config_params; + + param_type params; - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr non_trivial_runs_config_params params = RLENonTrivialRunsConfig{}; - }; + // Specialization for default config if types are not custom: instantiate the tuned config. + if constexpr(rocprim::is_arithmetic::value && rocprim::is_arithmetic::value + && rocprim::detail::is_binary_functional::value) + { + return run_length_encode_config_picker(); + } + else + { + return reduce_by_key_config_params_base(); + } + } + + template + constexpr run_length_encode_config_selector(Target) : params(picker_helper()) + {} }; -template -struct wrapped_non_trivial_runs_config< - rocprim::run_length_encode_config, - InputType> +template +struct convert_to_non_trivial_config { - template - struct architecture_config - { - // Mapping to non_trivial_runs_config. - // Beware that this mapping may impact performance of executions of - // run_length_encode_non_trivial_runs with the former run_length_encode_config, - // as it may not be the best for all cases. - static constexpr unsigned int block_size = ReduceByKeyConfig::block_size; - static constexpr unsigned int items_per_thread = ReduceByKeyConfig::items_per_thread; - - static constexpr block_load_method load_input_method = ReduceByKeyConfig::load_keys_method; - static constexpr block_scan_algorithm scan_algorithm = ReduceByKeyConfig::scan_algorithm; - - static constexpr non_trivial_runs_config_params params - = non_trivial_runs_config{}; - }; + using ReduceByKeyConfig = typename Config::reduce_by_key; + static constexpr unsigned int block_size = ReduceByKeyConfig::block_size; + static constexpr unsigned int items_per_thread = ReduceByKeyConfig::items_per_thread; + + static constexpr block_load_method load_input_method = ReduceByKeyConfig::load_keys_method; + static constexpr block_scan_algorithm scan_algorithm = ReduceByKeyConfig::scan_algorithm; + + using type + = non_trivial_runs_config; }; -// Generic for default config: instantiate base config. -template -struct wrapped_non_trivial_runs_impl +template<> +struct convert_to_non_trivial_config { - template - struct architecture_config - { - static constexpr non_trivial_runs_config_params params = - typename default_non_trivial_runs_config_base::type{}; - }; + using type = rocprim::default_config; }; -// Specialization for default config if types are arithmetic or half/bfloat16-precision -// floating point types: instantiate the tuned config. -template -struct wrapped_non_trivial_runs_impl::value>> +template +struct run_length_encode_non_trivial_config_selector { - template - struct architecture_config + using targets = run_length_encode_non_trivial_targets; + using param_type = non_trivial_runs_config_params; + + param_type params; + + template + constexpr param_type picker_helper() { - static constexpr non_trivial_runs_config_params params - = default_non_trivial_runs_config(Arch), InputType>{}; - }; + // Specialization for default config if types are not custom: instantiate the tuned config. + if constexpr(rocprim::is_arithmetic::value) + { + return run_length_encode_non_trivial_config_picker(); + } + else + { + return non_trivial_runs_config_params_base(); + } + } + + template + constexpr run_length_encode_non_trivial_config_selector(Target) + : params(picker_helper()) + {} }; -// Specialization for default config. -template -struct wrapped_non_trivial_runs_config - : wrapped_non_trivial_runs_impl -{}; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD - -template -template -constexpr reduce_by_key_config_params wrapped_trivial_runs_impl< - KeyType, - AccumulatorType, - BinaryFunction, - std::enable_if_t::value && is_arithmetic::value - && is_binary_functional::value>>::architecture_config:: - params; - -template -template -constexpr non_trivial_runs_config_params - wrapped_non_trivial_runs_config::architecture_config::params; - -template -template -constexpr non_trivial_runs_config_params wrapped_non_trivial_runs_config< - rocprim::run_length_encode_config, - InputType>::architecture_config::params; - -template -template -constexpr non_trivial_runs_config_params - wrapped_non_trivial_runs_impl::architecture_config::params; - -template -template -constexpr non_trivial_runs_config_params wrapped_non_trivial_runs_impl< - InputType, - std::enable_if_t::value>>::architecture_config::params; - -#endif // DOXYGEN_DOCUMENTATION_BUILD - } // end namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp index 1b6b011023a..8151620a1df 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp @@ -63,15 +63,17 @@ inline hipError_t segmented_reduce_impl(void* temporary_storage, using input_type = typename std::iterator_traits::value_type; using result_type = ::rocprim::accumulator_t; - using config = wrapped_segmented_reduce_config; + using Selector = segmented_reduce_config_selector; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const reduce_config_params params = dispatch_target_arch(target_arch); + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; @@ -102,12 +104,12 @@ inline hipError_t segmented_reduce_impl(void* temporary_storage, static_cast(initial_value)); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - segmented_reduce_kernel, - dim3(segments), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + segmented_reduce_kernel, + dim3(segments), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_reduce", segments, start); return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce_config.hpp index f64800af03c..0b929a18116 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce_config.hpp @@ -33,42 +33,20 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_segmented_reduce_config +template +struct segmented_reduce_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template reduce_config"); + using targets = segmented_reduce_targets; + using param_type = reduce_config_params; - template - struct architecture_config - { - static constexpr reduce_config_params params = ReduceConfig(); - }; -}; + param_type params; -template -struct wrapped_segmented_reduce_config -{ - template - struct architecture_config - { - static constexpr reduce_config_params params - = default_segmented_reduce_config(Arch), Value>(); - }; + template + constexpr segmented_reduce_config_selector(Target) + : params(segmented_reduce_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr reduce_config_params - wrapped_segmented_reduce_config::architecture_config::params; - -template -template -constexpr reduce_config_params - wrapped_segmented_reduce_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp index ddcd02836ac..f01012020be 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp @@ -52,15 +52,17 @@ inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_radix_sort_block_sort_config; + using Selector = radix_sort_block_sort_config_selector; - detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const kernel_config_params params = dispatch_target_arch(target_arch); + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target get_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, get_target); sort_items_per_block = params.block_size * params.items_per_thread; const unsigned int sort_number_of_blocks = ceiling_div(size, sort_items_per_block); @@ -99,14 +101,14 @@ inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(target_arch, - radix_sort_block_sort_kernel, - dim3(sort_number_of_blocks), - dim3(params.block_size), - 0, - stream)); + execute_launch_plan( + get_target, + radix_sort_block_sort_kernel, + dim3(sort_number_of_blocks), + dim3(params.block_size), + 0, + stream)); + ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("radix_sort_block_sort_kernel", size, start); return hipSuccess; } diff --git a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp index 35b416a9485..21da5f37cce 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp @@ -203,11 +203,13 @@ hipError_t radix_sort_merge_impl( // In the case that the user provides no custom config for merge sort block sort, // instead of using the autotuned merge_sort_block_sort_config, use a hard-coded config that // a power-of-two items sorted per block. - using default_block_sort_config = - typename radix_sort_block_sort_config_base::type; + constexpr auto default_block_sort_config + = radix_sort_block_sort_config_params_base(); + using radix_sort_block_sort_config = typename std::conditional, // extract the relevant config from merge_sort_block_sort_config typename Config::block_sort_config::sort_config>::type; static_assert( From c5d1ebfdfd926954442ee0fd137fe402c00e3a45 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 7 Nov 2025 06:38:20 +0000 Subject: [PATCH 04/26] Resolve "Update configs for new config system part 6" --- .../benchmark_device_transform.parallel.hpp | 15 +- .../detail/config/device_binary_search.hpp | 6789 +++++++---------- .../detail/config/device_lower_bound.hpp | 6789 +++++++---------- .../device/detail/config/device_transform.hpp | 1294 ++-- .../config/device_transform_pointer.hpp | 1232 ++- .../detail/config/device_upper_bound.hpp | 6789 +++++++---------- .../device/detail/device_config_helper.hpp | 35 +- .../rocprim/device/device_binary_search.hpp | 123 +- .../device/device_binary_search_config.hpp | 83 +- .../rocprim/device/device_transform.hpp | 52 +- .../device/device_transform_config.hpp | 67 +- 11 files changed, 9601 insertions(+), 13667 deletions(-) diff --git a/projects/rocprim/benchmark/benchmark_device_transform.parallel.hpp b/projects/rocprim/benchmark/benchmark_device_transform.parallel.hpp index 7a0523e0864..1aa283197d2 100644 --- a/projects/rocprim/benchmark/benchmark_device_transform.parallel.hpp +++ b/projects/rocprim/benchmark/benchmark_device_transform.parallel.hpp @@ -65,7 +65,6 @@ template struct device_transform_benchmark : public benchmark_utils::autotune_interface { - std::string name() const override { @@ -122,13 +121,15 @@ struct device_transform_benchmark : public benchmark_utils::autotune_interface { const auto launch = [&] { + using Selector = rocprim::detail::transform_config_selector; auto transform_op = [](T v) { return v + T(5); }; - return rocprim::detail::transform_impl(d_input.get(), - d_output.get(), - size, - transform_op, - stream, - debug_synchronous); + return rocprim::detail::transform_impl( + d_input.get(), + d_output.get(), + size, + transform_op, + stream, + debug_synchronous); }; state.run([&] { HIP_CHECK(launch()); }); diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp index f2705bfd361..bdaf163efd8 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_binary_search.hpp @@ -40,4045 +40,2756 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_binary_search_config : default_binary_search_config_base -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 16> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 16> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 8> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 8> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 2> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 8> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : binary_search_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : binary_search_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : binary_search_config<256, 4> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : binary_search_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : binary_search_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : binary_search_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_binary_search_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : binary_search_config<128, 1> -{}; +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 2} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto binary_search_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return binary_search_config_picker< + comp_target, + value_type, + output_type>(); +} + +// All the existing configs should be auto generated +using binary_search_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp index 13e05329597..ed8e240ca2f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_lower_bound.hpp @@ -40,4045 +40,2756 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_lower_bound_config : default_binary_search_config_base -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 16> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 2> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 2> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 2> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 4> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 2> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 2> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : lower_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : lower_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : lower_bound_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : lower_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : lower_bound_config<128, 4> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : lower_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : lower_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : lower_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_lower_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : lower_bound_config<128, 1> -{}; +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 2} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto lower_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return lower_bound_config_picker< + comp_target, + value_type, + output_type>(); +} + +// All the existing configs should be auto generated +using lower_bound_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform.hpp index e513592ce57..12ed641df43 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform.hpp @@ -40,702 +40,604 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_transform_config : default_transform_config_base::type -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<256, 1> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<1024, 1> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<256, 1> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<128, 2> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 1> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<1024, 2> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 1> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<64, 1> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<1024, 4> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<512, 1> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<512, 2> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<512, 1> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<512, 2> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<256, 4> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<64, 2> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<512, 4> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<512, 2> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 4> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<1024, 8> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<256, 4> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<512, 4> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 1> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<512, 4> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<64, 16> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<128, 1> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<128, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<128, 4> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<128, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<128, 2> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<128, 4> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<128, 8> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<256, 2> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<64, 8> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<1024, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<256, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<1024, 2> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<64, 8> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<64, 16> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<128, 1> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<128, 2> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<128, 4> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<128, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<128, 2> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<128, 4> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<128, 8> -{}; - -// Based on value_type = double -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 4> -{}; - -// Based on value_type = float -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<256, 4> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : transform_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<512, 1> -{}; - -// Based on value_type = int64_t -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_config<512, 4> -{}; - -// Based on value_type = short -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_config<256, 8> -{}; - -// Based on value_type = int8_t -template -struct default_transform_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : transform_config<1024, 8> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_config<64, 1> -{}; +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 4} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 8} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {512, 1} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {512, 4} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 8} + }; + } + // Default case if none of the conditions match + return transform_config_params_base(); +} + +template +constexpr auto transform_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return transform_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using transform_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp index bf2e4bc3c8c..e936d7581ea 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_transform_pointer.hpp @@ -40,650 +40,594 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_transform_pointer_config : default_transform_pointer_config_base::type -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<1024, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<512, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<512, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<512, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<128, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<1024, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<1024, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<1024, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<128, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<128, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<128, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<128, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<256, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<256, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<256, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<512, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<256, 4, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<256, 8, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<256, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = double -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = float -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::half -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : transform_pointer_config<64, 1, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_default> -{}; - -// Based on value_type = int64_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : transform_pointer_config<1024, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = short -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : transform_pointer_config<64, 2, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; - -// Based on value_type = int8_t -template -struct default_transform_pointer_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> - : transform_pointer_config<1024, 16, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::load_nontemporal> -{}; +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_default + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 16}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {512, 4}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 8}, + ::rocprim::load_default + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {512, 8}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {128, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {128, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {128, 4}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {128, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {128, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {128, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {128, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {128, 8}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 8}, + ::rocprim::load_default + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {1024, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_default + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 4}, + ::rocprim::load_default + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {1024, 8}, + ::rocprim::load_default + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 16}, + ::rocprim::load_default + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 4}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {256, 8}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {256, 1}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {512, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {256, 4}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {256, 8}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {256, 16}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return transform_config_params{ + {64, 1}, + ::rocprim::load_default + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return transform_config_params{ + {1024, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return transform_config_params{ + {64, 2}, + ::rocprim::load_nontemporal + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return transform_config_params{ + {1024, 16}, + ::rocprim::load_nontemporal + }; + } + // Default case if none of the conditions match + return transform_pointer_config_params_base(); +} + +template +constexpr auto transform_pointer_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return transform_pointer_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using transform_pointer_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp index 1a2df25c825..5c17ef4e819 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_upper_bound.hpp @@ -40,4045 +40,2756 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_upper_bound_config : default_binary_search_config_base -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<128, 8> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<128, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<128, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1030), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 16> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 16> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1100), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1200), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 2> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx906), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx908), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<128, 2> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx90a), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 2> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 8> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 8> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 16> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<64, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<128, 8> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::unknown), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = double, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::half, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<256, 8> -{}; - -// Based on value_type = rocprim::half, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<256, 2> -{}; - -// Based on value_type = rocprim::half, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<64, 1> -{}; - -// Based on value_type = int64_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 1))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = int, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = short, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 1> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 8) - && (sizeof(output_type) > 4))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 4) - && (sizeof(output_type) > 2))>> : upper_bound_config<64, 1> -{}; - -// Based on value_type = int8_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 2) - && (sizeof(output_type) > 1))>> : upper_bound_config<128, 4> -{}; - -// Based on value_type = int8_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx942), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 1))>> - : upper_bound_config<256, 4> -{}; - -// Based on value_type = double, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = float, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::half, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<128, 1> -{}; - -// Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int64_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = short -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))>> - : upper_bound_config<256, 16> -{}; - -// Based on value_type = rocprim::int128_t, output_type = int8_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) - && (sizeof(output_type) <= 1))>> : upper_bound_config<256, 16> -{}; - -// Based on value_type = int64_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 1> -{}; - -// Based on value_type = int, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<128, 1> -{}; - -// Based on value_type = short, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) - && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))>> - : upper_bound_config<256, 2> -{}; - -// Based on value_type = int8_t, output_type = rocprim::int128_t -template -struct default_upper_bound_config< - static_cast(target_arch::gfx1201), - value_type, - output_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (sizeof(output_type) <= 16) - && (sizeof(output_type) > 8))>> : upper_bound_config<64, 1> -{}; +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 8} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 16} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 16} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {128, 8} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 4} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 2} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 2} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + // Based on value_type = double, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = double, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = float, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = float, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = rocprim::half, output_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::half, output_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 8} + }; + } + // Based on value_type = rocprim::half, output_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 2} + }; + } + // Based on value_type = rocprim::half, output_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Based on value_type = rocprim::int128_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = rocprim::int128_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int64_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 1} + }; + } + // Based on value_type = int, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 16) + && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 8) + && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 4) + && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 2) + && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = short, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 1} + }; + } + // Based on value_type = int8_t, output_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 16) && (sizeof(output_type) > 8))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 8) && (sizeof(output_type) > 4))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 4) && (sizeof(output_type) > 2))) + { + return transform_config_params{ + {64, 1} + }; + } + // Based on value_type = int8_t, output_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 2) && (sizeof(output_type) > 1))) + { + return transform_config_params{ + {128, 4} + }; + } + // Based on value_type = int8_t, output_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (sizeof(output_type) <= 1))) + { + return transform_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return binary_search_config_params_base(); +} + +template +constexpr auto upper_bound_config_picker() -> std::enable_if_t< + std::is_same>::value, + transform_config_params> +{ + return upper_bound_config_picker< + comp_target, + value_type, + output_type>(); +} + +// All the existing configs should be auto generated +using upper_bound_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index e37f7232dce..474af8edd7c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -452,8 +452,8 @@ struct transform_config_tag struct transform_config_params { - kernel_config_params kernel_config{}; - cache_load_modifier load_type; + kernel_config_params kernel_config = {0, 0}; + cache_load_modifier load_type = load_default; }; } // namespace detail @@ -725,22 +725,26 @@ namespace detail { template -struct default_transform_pointer_config_base +constexpr transform_config_params transform_pointer_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(uint128_t), sizeof(Value)); - using type = transform_config<256, item_scale>; -}; + return transform_config_params{ + {256, item_scale} + }; +} template -struct default_transform_config_base +constexpr transform_config_params transform_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = transform_config<256, ::rocprim::max(1u, 16u / item_scale)>; -}; + return transform_config_params{ + {256, ::rocprim::max(1u, 16u / item_scale)} + }; +} struct binary_search_config_tag : public transform_config_tag {}; @@ -797,11 +801,12 @@ struct histogram_config_tag {}; template -struct default_binary_search_config_base - : binary_search_config< - limit_block_size<256U, sizeof(Value) + sizeof(Output), ROCPRIM_WARP_SIZE_64>::value, - 1> -{}; +constexpr transform_config_params binary_search_config_params_base() +{ + return transform_config_params{ + {limit_block_size<256U, sizeof(Value) + sizeof(Output), ROCPRIM_WARP_SIZE_64>::value, 1} + }; +} /// \brief Provides the kernel parameters for histogram_even, multi_histogram_even, /// histogram_range, and multi_histogram_range based on autotuned configurations or diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp index e4dd5f2218e..effaf4ccf5d 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp @@ -39,26 +39,24 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template< - class Config, - class HaystackIterator, - class NeedlesIterator, - class OutputIterator, - class SearchFunction, - class CompareFunction -> -inline -hipError_t binary_search(void * temporary_storage, - size_t& storage_size, - HaystackIterator haystack, - NeedlesIterator needles, - OutputIterator output, - size_t haystack_size, - size_t needles_size, - SearchFunction search_op, - CompareFunction compare_op, - hipStream_t stream, - bool debug_synchronous) +template +inline hipError_t binary_search(void* temporary_storage, + size_t& storage_size, + HaystackIterator haystack, + NeedlesIterator needles, + OutputIterator output, + size_t haystack_size, + size_t needles_size, + SearchFunction search_op, + CompareFunction compare_op, + hipStream_t stream, + bool debug_synchronous) { using value_type = typename std::iterator_traits::value_type; @@ -70,7 +68,9 @@ hipError_t binary_search(void * temporary_storage, return hipSuccess; } - return detail::transform_impl( + constexpr bool is_pointer + = false; // We do not use the optimization for transform when input is a pointer. + return detail::transform_impl( needles, output, needles_size, @@ -214,22 +214,19 @@ hipError_t lower_bound(void * temporary_storage, using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config - = std::conditional_t::value, - detail::default_config_for_lower_bound, - Config>; + using selector = detail::lower_bound_config_selector; - return detail::binary_search(temporary_storage, - storage_size, - haystack, - needles, - output, - haystack_size, - needles_size, - detail::lower_bound_search_op(), - compare_op, - stream, - debug_synchronous); + return detail::binary_search(temporary_storage, + storage_size, + haystack, + needles, + output, + haystack_size, + needles_size, + detail::lower_bound_search_op(), + compare_op, + stream, + debug_synchronous); } /// \brief Parallel primitive that uses binary search for computing an upper bound on a given ordered @@ -350,22 +347,19 @@ hipError_t upper_bound(void * temporary_storage, "Config must be a specialization of struct template upper_bound_config"); using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config - = std::conditional_t::value, - detail::default_config_for_upper_bound, - Config>; + using selector = detail::upper_bound_config_selector; - return detail::binary_search(temporary_storage, - storage_size, - haystack, - needles, - output, - haystack_size, - needles_size, - detail::upper_bound_search_op(), - compare_op, - stream, - debug_synchronous); + return detail::binary_search(temporary_storage, + storage_size, + haystack, + needles, + output, + haystack_size, + needles_size, + detail::upper_bound_search_op(), + compare_op, + stream, + debug_synchronous); } /// \brief Parallel primitive for performing a binary search (on a sorted range) of a given input. @@ -481,22 +475,19 @@ hipError_t binary_search(void * temporary_storage, "Config must be a specialization of struct template binary_search_config"); using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config - = std::conditional_t::value, - detail::default_config_for_binary_search, - Config>; + using selector = detail::binary_search_config_selector; - return detail::binary_search(temporary_storage, - storage_size, - haystack, - needles, - output, - haystack_size, - needles_size, - detail::binary_search_op(), - compare_op, - stream, - debug_synchronous); + return detail::binary_search(temporary_storage, + storage_size, + haystack, + needles, + output, + haystack_size, + needles_size, + detail::binary_search_op(), + compare_op, + stream, + debug_synchronous); } END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search_config.hpp index 3580830b88d..79c038ae52e 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search_config.hpp @@ -39,67 +39,46 @@ namespace detail { template -struct default_config_for_binary_search -{}; - -template -struct default_config_for_upper_bound -{}; +struct binary_search_config_selector +{ + using targets = binary_search_targets; + using param_type = transform_config_params; -template -struct default_config_for_lower_bound -{}; + param_type params; -template -struct wrapped_transform_config, Unused, IsPointer> -{ - template - struct architecture_config - { - static constexpr transform_config_params params - = default_binary_search_config(Arch), Value, Output>{}; - }; + template + constexpr binary_search_config_selector(Target) + : params(binary_search_config_picker()) + {} }; -template -struct wrapped_transform_config, Unused, IsPointer> +template +struct lower_bound_config_selector { - template - struct architecture_config - { - static constexpr transform_config_params params - = default_upper_bound_config(Arch), Value, Output>{}; - }; + using targets = lower_bound_targets; + using param_type = transform_config_params; + + param_type params; + + template + constexpr lower_bound_config_selector(Target) + : params(lower_bound_config_picker()) + {} }; -template -struct wrapped_transform_config, Unused, IsPointer> +template +struct upper_bound_config_selector { - template - struct architecture_config - { - static constexpr transform_config_params params - = default_lower_bound_config(Arch), Value, Output>{}; - }; -}; + using targets = upper_bound_targets; + using param_type = transform_config_params; + + param_type params; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr transform_config_params - wrapped_transform_config, Unused, IsPointer>:: - architecture_config::params; -template -template -constexpr transform_config_params - wrapped_transform_config, Unused, IsPointer>:: - architecture_config::params; -template -template -constexpr transform_config_params - wrapped_transform_config, Unused, IsPointer>:: - architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr upper_bound_config_selector(Target) + : params(upper_bound_config_picker()) + {} +}; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp index 80e36fca6bf..b1c1679e9e3 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp @@ -48,6 +48,7 @@ namespace detail template @@ -58,24 +59,23 @@ inline hipError_t transform_impl(InputIterator input, const hipStream_t stream, bool debug_synchronous) { + using input_type = typename std::iterator_traits::value_type; + using result_type = typename ::rocprim::invoke_result::type; + if(size == size_t(0)) { return hipSuccess; } - using input_type = typename std::iterator_traits::value_type; - using result_type = typename ::rocprim::invoke_result::type; + detail::target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - using config = detail::wrapped_transform_config; + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const detail::transform_config_params params - = detail::dispatch_target_arch(target_arch); + const target get_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -123,12 +123,12 @@ inline hipError_t transform_impl(InputIterator input, transform_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - transform_kernel, - current_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + transform_kernel, + current_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("transform_kernel", current_size, start); } @@ -204,13 +204,17 @@ inline hipError_t transform(InputIterator input, constexpr bool is_pointer = std::is_pointer::value && std::is_pointer::value; - return detail::transform_impl( - input, - output, - size, - transform_op, - stream, - debug_synchronous); + using input_type = typename std::iterator_traits::value_type; + using selector = detail::transform_config_selector; + + return detail:: + transform_impl( + input, + output, + size, + transform_op, + stream, + debug_synchronous); } /// \brief Parallel device-level transform primitive for two inputs. diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_transform_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_transform_config.hpp index a74bc163fee..4327b545023 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_transform_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_transform_config.hpp @@ -39,58 +39,33 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_transform_config +template +struct transform_config_selector { - static_assert(std::is_base_of::value, - "Config must be a specialization of struct template transform_config"); + using targets = std::conditional_t; + using param_type = transform_config_params; - template - struct architecture_config - { - static constexpr transform_config_params params = TransformConfig{}; - }; -}; - -template -struct wrapped_transform_config -{ - template - struct architecture_config - { - static constexpr transform_config_params params - = default_transform_pointer_config(Arch), Value>{}; - }; -}; + param_type params; -template -struct wrapped_transform_config -{ - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr transform_config_params params - = default_transform_config(Arch), Value>{}; - }; + // Different configs if it is a pointer. + if constexpr(IsPointer) + { + return transform_pointer_config_picker(); + } + else + { + return transform_config_picker(); + } + } + + template + constexpr transform_config_selector(Target) : params(picker_helper()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr transform_config_params - wrapped_transform_config::architecture_config::params; - -template -template -constexpr transform_config_params - wrapped_transform_config::architecture_config::params; - -template -template -constexpr transform_config_params - wrapped_transform_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // end namespace detail END_ROCPRIM_NAMESPACE From 74d868921db1411ba92af567a41e4b75db032a74 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 7 Nov 2025 10:07:19 +0000 Subject: [PATCH 05/26] Resolve "Update configs for new config system part 3" --- .../include/rocprim/device/config_types.hpp | 5 +- .../detail/config/device_partition_flag.hpp | 1119 ++-- .../config/device_partition_predicate.hpp | 1119 ++-- .../config/device_partition_three_way.hpp | 1119 ++-- .../config/device_partition_two_way_flag.hpp | 1119 ++-- .../device_partition_two_way_predicate.hpp | 1120 ++-- .../detail/config/device_select_flag.hpp | 1119 ++-- .../detail/config/device_select_predicate.hpp | 1119 ++-- .../config/device_select_predicated_flag.hpp | 5902 +++++++---------- .../detail/config/device_select_unique.hpp | 1119 ++-- .../config/device_select_unique_by_key.hpp | 5773 +++++++--------- .../device/detail/device_config_helper.hpp | 30 +- .../rocprim/device/device_partition.hpp | 84 +- .../device/device_partition_config.hpp | 311 +- .../device/device_segmented_radix_sort.hpp | 8 +- 15 files changed, 9077 insertions(+), 11989 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index ab20e0922c6..879923c57dc 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -647,7 +647,10 @@ void trampoline_kernel(Kernel kernel) using ArchConfig = target_config2; #if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 - if constexpr(Target::i == device_target_arch()) + using Targets = typename Selector::targets; + // If the arch does not exist in the Targets it should run the arch for the most_common_config. + constexpr target device_arch_target = most_common_config(target(device_target_arch())); + if constexpr(Target::i == device_arch_target.i) #endif { kernel(ArchConfig{}); diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp index 0c0a1b540f0..141b6bf1fd2 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_flag.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_flag_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 5> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 5> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 12> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 12> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 19> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<128, 28> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 4> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 12> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 7> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 12> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 12> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 12> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 5} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 5} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 19} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {128, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_flag_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_flag_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp index 80658c1166e..29bce571c57 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_predicate.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_predicate_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<128, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 9> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 18> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 9> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 24> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 7> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 28> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<192, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<192, 15> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 23> -{}; +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {128, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 23} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_predicate_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_predicate_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp index eac8d66329e..53ababfdf51 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_three_way.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_three_way_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 6> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 17> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 9> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 18> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<384, 4> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 26> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<384, 4> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 10> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 14> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 15> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 15> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 6> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 14> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 15> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 11> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 16> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 17> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 6> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 14> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 15> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 15> -{}; - -// Based on data_type = double -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 18> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 18> -{}; - -// Based on data_type = int8_t -template -struct default_partition_three_way_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 18> -{}; +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 17} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 9} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 26} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 15} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 15} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 17} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_three_way_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_three_way_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_three_way_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp index 82dfb30f2d3..ac23fa8ec65 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_flag.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_two_way_flag_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 6> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 12> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 20> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 18> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 6> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 11> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 17> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<128, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 12> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 6> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<384, 14> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 7> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 11} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 17} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {128, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {384, 14} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_two_way_flag_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_two_way_flag_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp index a6fdba90c1c..e2132d512b9 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_partition_two_way_predicate.hpp @@ -40,609 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_partition_two_way_predicate_config - : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 8> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<192, 9> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<192, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 8> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 18> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 4> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 18> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<384, 18> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<192, 15> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 26> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_partition_two_way_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 9} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {192, 12} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto partition_two_way_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return partition_two_way_predicate_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using partition_two_way_predicate_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp index 7ab24f63a9a..95d57ce9a55 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_flag.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_flag_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 6> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 6> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 6> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 6> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 12> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 5> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 12> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<192, 8> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 12> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 13> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 12> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 7> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 12> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 12> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 28> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 28> -{}; - -// Based on data_type = int8_t -template -struct default_select_flag_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_flag_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using select_flag_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp index 9efc5124b5e..c43ce0896f4 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicate.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_predicate_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 6> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<384, 6> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 14> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<192, 22> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<384, 6> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<128, 14> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 4> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 18> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 7> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<384, 18> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<192, 15> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 26> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 14> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 26> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 5> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 10> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<128, 7> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<128, 7> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 14> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 26> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = double -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_predicate_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 14} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {192, 22} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 14} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {192, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 26} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 5} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicate_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_predicate_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using select_predicate_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp index 1b8faafb244..63f1ecf26a4 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp @@ -40,3511 +40,2401 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_predicated_flag_config : default_partition_config_base::type -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<256, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<256, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 25> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<512, 16> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<256, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<512, 4> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 8> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<256, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 8> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1030), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<512, 16> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 4> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 8> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<512, 8> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<512, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<192, 6> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<192, 8> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<192, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 16> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<192, 22> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 8> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<512, 8> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 6> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 8> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<192, 12> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 18> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<512, 20> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<512, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 20> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1100), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<512, 20> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 4> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 4> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<256, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<384, 4> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<384, 4> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<384, 6> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<384, 7> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<384, 8> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<384, 18> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<256, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<384, 4> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<384, 4> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 6> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<384, 7> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 4> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<384, 8> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 12> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<384, 18> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 8> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 16> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx1200), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 7> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<192, 7> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 7> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<256, 7> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 10> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<256, 11> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 18> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 18> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<256, 3> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 7> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 7> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<256, 7> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 11> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 10> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 11> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 18> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<256, 18> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 24> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx906), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<192, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<128, 7> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<128, 11> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 10> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<256, 11> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 18> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<128, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<128, 7> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 11> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 10> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 11> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 18> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<256, 18> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 24> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx908), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<128, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 5> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 5> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 6> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<192, 10> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<256, 7> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<128, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<192, 5> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<192, 5> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 6> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<192, 10> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<192, 10> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 20> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<192, 5> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 12> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<192, 20> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx90a), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<128, 6> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<128, 7> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<128, 11> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 10> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<256, 11> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 18> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<128, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<128, 6> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<128, 7> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<128, 11> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<256, 10> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<256, 11> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 18> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<256, 18> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<192, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<256, 6> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<256, 11> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<256, 24> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::unknown), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on data_type = double, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<384, 4> -{}; - -// Based on data_type = double, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = double, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 7> -{}; - -// Based on data_type = double, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 7> -{}; - -// Based on data_type = double, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))>> - : select_config<512, 7> -{}; - -// Based on data_type = float, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<384, 4> -{}; - -// Based on data_type = float, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = float, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 15> -{}; - -// Based on data_type = float, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))>> - : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::half, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = rocprim::half, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 14> -{}; - -// Based on data_type = rocprim::half, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = rocprim::half, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) - && (sizeof(flag_type) <= 1))>> : select_config<512, 28> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<384, 4> -{}; - -// Based on data_type = rocprim::int128_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8) - && (sizeof(flag_type) <= 1))>> : select_config<384, 4> -{}; - -// Based on data_type = int64_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int64_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 7> -{}; - -// Based on data_type = int64_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 7> -{}; - -// Based on data_type = int64_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4) - && (sizeof(flag_type) <= 1))>> : select_config<512, 7> -{}; - -// Based on data_type = int, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = int, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 14> -{}; - -// Based on data_type = int, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2) - && (sizeof(flag_type) <= 1))>> : select_config<512, 15> -{}; - -// Based on data_type = short, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = short, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = short, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))>> - : select_config<512, 14> -{}; - -// Based on data_type = short, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = short, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1) - && (sizeof(flag_type) <= 1))>> : select_config<512, 28> -{}; - -// Based on data_type = int8_t, flag_type = rocprim::int128_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 16) - && (sizeof(flag_type) > 8))>> : select_config<384, 4> -{}; - -// Based on data_type = int8_t, flag_type = int64_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 8) - && (sizeof(flag_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = int8_t, flag_type = int -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 4) - && (sizeof(flag_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = int8_t, flag_type = short -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 2) - && (sizeof(flag_type) > 1))>> : select_config<512, 28> -{}; - -// Based on data_type = int8_t, flag_type = int8_t -template -struct default_select_predicated_flag_config< - static_cast(target_arch::gfx942), - data_type, - flag_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1) && (sizeof(flag_type) <= 1))>> - : select_config<512, 24> -{}; +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 25} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 6} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 12} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 22} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 6} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 12} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 20} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 20} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 8} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 11} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 18} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 11} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 18} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 11} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 5} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {192, 20} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = double, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = double, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = double, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = double, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = float, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = float, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = float, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half, flag_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::half, flag_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = rocprim::half, flag_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on data_type = rocprim::half, flag_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::half, flag_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = rocprim::int128_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int64_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int64_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int64_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = int, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on data_type = int, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 16) + && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = short, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 8) + && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = short, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 4) + && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on data_type = short, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 2) + && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = short, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1) && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t, flag_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 16) && (sizeof(flag_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int8_t, flag_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 8) && (sizeof(flag_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int8_t, flag_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 4) && (sizeof(flag_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = int8_t, flag_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) + { + return partition_config_params{ + {512, 28} + }; + } + // Based on data_type = int8_t, flag_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1) + && (sizeof(flag_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_predicated_flag_config_picker< + comp_target, + data_type, + flag_type>(); +} + +// All the existing configs should be auto generated +using select_predicated_flag_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail @@ -3553,4 +2443,4 @@ END_ROCPRIM_NAMESPACE /// @} // end of group primitivesmodule_deviceconfigs -#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ \ No newline at end of file +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp index 6088fe77ea1..e8adf223d76 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique.hpp @@ -40,608 +40,523 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_unique_config : default_partition_config_base::type -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 8> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 12> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<384, 18> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1030), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 8> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<128, 14> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<384, 22> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<512, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 16> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<128, 20> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1100), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<384, 28> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<384, 8> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<384, 9> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 21> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<384, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<384, 9> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 21> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx1200), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 20> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 7> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 16> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 18> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 16> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<192, 18> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx906), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<192, 18> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 11> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 11> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx908), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<256, 5> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<256, 10> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 22> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<256, 3> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<256, 5> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<256, 10> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 22> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx90a), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<128, 32> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 11> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<256, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<128, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 11> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<256, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::unknown), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<256, 28> -{}; - -// Based on data_type = double -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) - && (sizeof(data_type) > 4))>> : select_config<512, 7> -{}; - -// Based on data_type = float -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) - && (sizeof(data_type) > 2))>> : select_config<512, 15> -{}; - -// Based on data_type = rocprim::half -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2))>> : select_config<512, 30> -{}; - -// Based on data_type = rocprim::int128_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 16) && (sizeof(data_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on data_type = int64_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 8) && (sizeof(data_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on data_type = int -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 4) && (sizeof(data_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on data_type = short -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 2) && (sizeof(data_type) > 1))>> - : select_config<512, 30> -{}; - -// Based on data_type = int8_t -template -struct default_select_unique_config< - static_cast(target_arch::gfx942), - data_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(data_type) <= 1))>> : select_config<512, 32> -{}; +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 12} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {384, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {128, 14} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {384, 22} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 16} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {128, 20} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {384, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 21} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 21} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 18} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {192, 18} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {192, 18} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 11} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 11} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {256, 22} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {256, 10} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {256, 22} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {128, 32} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on data_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 16) + && (sizeof(data_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on data_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 8) + && (sizeof(data_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on data_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 4) + && (sizeof(data_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on data_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 2) + && (sizeof(data_type) > 1))) + { + return partition_config_params{ + {512, 30} + }; + } + // Based on data_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(data_type) <= 1))) + { + return partition_config_params{ + {512, 32} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_unique_config_picker< + comp_target, + data_type>(); +} + +// All the existing configs should be auto generated +using select_unique_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp index 958e0c25f9a..7f0711166b8 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_unique_by_key.hpp @@ -40,3382 +40,2401 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_select_unique_by_key_config : default_partition_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<512, 6> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<384, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<384, 14> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<512, 6> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<512, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 10> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<384, 14> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 6> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<384, 14> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<384, 12> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<128, 7> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 10> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<128, 12> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<192, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<128, 7> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 12> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<128, 12> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 11> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 12> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<512, 20> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<384, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<128, 13> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<128, 22> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<384, 20> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 6> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 6> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<384, 6> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 7> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 9> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<384, 12> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<384, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 14> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<512, 14> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 6> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 6> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<384, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 7> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 9> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<384, 12> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 16> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 24> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<512, 14> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 32> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<384, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 5> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 6> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 7> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 14> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<192, 14> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 3> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 6> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 7> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 14> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 12> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<192, 14> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<192, 17> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 5> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 6> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 6> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 6> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 3> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 5> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 6> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 7> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 6> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 12> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<192, 17> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<256, 5> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 5> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 7> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 10> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 9> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<192, 5> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<192, 10> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 24> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<256, 24> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 3> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<256, 5> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 5> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 7> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 10> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 9> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 5> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 10> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 24> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<256, 24> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<192, 10> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 28> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<192, 28> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 5> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 6> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<128, 6> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 6> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<256, 6> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 12> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<256, 3> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 5> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<128, 6> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<192, 7> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<192, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<128, 6> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 13> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<192, 16> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<192, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<256, 6> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<256, 12> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<256, 14> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<192, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<192, 7> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<256, 13> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<256, 14> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<192, 17> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = double, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 7> -{}; - -// Based on key_type = double, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 7> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<512, 7> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = float, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 15> -{}; - -// Based on key_type = float, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 14> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<512, 15> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 14> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 22> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1))>> : select_config<512, 24> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<384, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<384, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<384, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))>> - : select_config<384, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 7> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 7> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))>> - : select_config<512, 7> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = int, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 15> -{}; - -// Based on key_type = int, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 14> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))>> - : select_config<512, 15> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : select_config<384, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : select_config<512, 7> -{}; - -// Based on key_type = short, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : select_config<512, 14> -{}; - -// Based on key_type = short, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : select_config<512, 22> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))>> - : select_config<512, 24> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : select_config<384, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : select_config<512, 7> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : select_config<512, 15> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : select_config<512, 24> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_select_unique_by_key_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1))>> : select_config<512, 24> -{}; +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 14} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 10} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 14} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 6} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 14} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 10} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {128, 7} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {128, 11} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 12} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 20} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {128, 13} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 22} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 20} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 8} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 7} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 9} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 12} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 6} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 32} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 16} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 14} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 14} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 14} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 14} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 17} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 5} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 5} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 6} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 16} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 12} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {256, 13} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 14} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 17} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 9} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {256, 3} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {128, 4} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 5} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 7} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 9} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {192, 5} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {256, 24} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {192, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {256, 6} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {192, 10} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {256, 28} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {192, 28} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 22} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 14} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 22} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return partition_config_params{ + {384, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return partition_config_params{ + {512, 7} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return partition_config_params{ + {512, 15} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1))) + { + return partition_config_params{ + {512, 24} + }; + } + // Default case if none of the conditions match + return partition_config_params_base(); +} + +template +constexpr auto select_unique_by_key_config_picker() -> std::enable_if_t< + std::is_same>::value, + partition_config_params> +{ + return select_unique_by_key_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using select_unique_by_key_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail @@ -3424,4 +2443,4 @@ END_ROCPRIM_NAMESPACE /// @} // end of group primitivesmodule_deviceconfigs -#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_BY_KEY_HPP_ \ No newline at end of file +#endif // ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_BY_KEY_HPP_ diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 474af8edd7c..5a63461667f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -1052,11 +1052,11 @@ namespace detail struct partition_config_params { - kernel_config_params kernel_config; - block_load_method key_block_load_method; - block_load_method value_block_load_method; - block_load_method flag_block_load_method; - block_scan_algorithm block_scan_method; + kernel_config_params kernel_config = {0, 0}; + block_load_method key_block_load_method = ::rocprim::block_load_method::block_load_transpose; + block_load_method value_block_load_method = ::rocprim::block_load_method::block_load_transpose; + block_load_method flag_block_load_method = ::rocprim::block_load_method::block_load_transpose; + block_scan_algorithm block_scan_method = ::rocprim::block_scan_algorithm::using_warp_scan; }; } // namespace detail @@ -1113,26 +1113,24 @@ struct select_config : public detail::partition_config_params namespace detail { -template -struct default_partition_config_base +template +constexpr partition_config_params partition_config_params_base() { - static constexpr unsigned int item_scale + constexpr int ItemScaleBase = 13; + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); using offset_t = std::conditional_t; // Additional shared memory is required by the lookback scan state. - static constexpr unsigned int shared_mem_offset = sizeof( + constexpr unsigned int shared_mem_offset = sizeof( typename offset_lookback_scan_prefix_op>::storage_type); - using type = select_config< - fallback_block_size<256U, sizeof(Key), ROCPRIM_WARP_SIZE_64, shared_mem_offset>::value, - ::rocprim::max(1u, ItemScaleBase / item_scale), - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_scan_algorithm::using_warp_scan>; + return partition_config_params{ + {fallback_block_size<256U, sizeof(Key), ROCPRIM_WARP_SIZE_64, shared_mem_offset>::value, + ::rocprim::max(1u, ItemScaleBase / item_scale)} + }; }; struct reduce_by_key_config_params diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp index 0f9db2d5ad9..0b9ad04f7a0 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp @@ -52,6 +52,7 @@ namespace detail { template -inline size_t get_partition_vsmem_size_per_block(detail::target_arch arch) +inline size_t get_partition_vsmem_size_per_block(detail::target t) { + using targets = typename Selector::targets; using offset_type = typename OffsetLookbackScanState::value_type; - std::optional vsmem_per_block; - for_each_arch( - [&](auto arch_tag) + + size_t vsmem_per_block = 0; + + targets::for_each( + [&](auto candidate) { - constexpr target_arch Arch = decltype(arch_tag)::value; - if(Arch != arch || vsmem_per_block) + if(target{candidate} == most_common_config(t)) { - return; + using ArchConfig = target_config2; + using partition_kernel_impl_t = partition_kernel_impl_; + + using partition_vsmem_helper_t = detail::vsmem_helper_impl; + vsmem_per_block = partition_vsmem_helper_t::vsmem_per_block; } - - using ArchConfig = typename Config::template architecture_config; - using partition_kernel_impl_t = partition_kernel_impl_; - using partition_vsmem_helper_t = detail::vsmem_helper_impl; - - vsmem_per_block = partition_vsmem_helper_t::vsmem_per_block; }); - if(!vsmem_per_block) - { - using ArchConfig = typename Config::template architecture_config; - using partition_kernel_impl_t = partition_kernel_impl_; - using partition_vsmem_helper_t = detail::vsmem_helper_impl; - - vsmem_per_block = partition_vsmem_helper_t::vsmem_per_block; - } - return vsmem_per_block.value(); + + return vsmem_per_block; } template; using block_id_type = detail::block_id_wrapper; - using config = wrapped_partition_config; + using selector = partition_config_selector; constexpr bool write_only_selected = SubAlgo == partition_subalgo::select_flag @@ -175,8 +162,11 @@ inline hipError_t partition_impl(void* temporary_storage, detail::target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const partition_config_params params = dispatch_target_arch(target_arch); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target get_target(target_arch, target_gpu); + const auto params = get_config(Config{}, get_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const auto items_per_block = block_size * items_per_thread; @@ -214,14 +204,15 @@ inline hipError_t partition_impl(void* temporary_storage, bool>::type; virtual_shared_memory_size - = get_partition_vsmem_size_per_block(target_arch); + block_id_type>(get_target); virtual_shared_memory_size *= number_of_blocks; // temporary storage partition @@ -346,12 +337,13 @@ inline hipError_t partition_impl(void* temporary_storage, storage, predicates...); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - partition_kernel, - dim3(current_number_of_blocks), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(get_target, + partition_kernel, + dim3(current_number_of_blocks), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", size, start); std::swap(selected_count, prev_selected_count); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_partition_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_partition_config.hpp index 69c92f0d7ac..4c97615ff39 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_partition_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_partition_config.hpp @@ -42,245 +42,108 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_partition_config +template +constexpr auto algo_target_type() { - template - struct architecture_config + if constexpr(Algo == partition_subalgo::partition_two_way_predicate) { - static constexpr partition_config_params params = PartitionConfig{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::partition_two_way_flag) { - static constexpr partition_config_params params - = default_partition_two_way_predicate_config(Arch), - KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::partition_flag) { - static constexpr partition_config_params params - = default_partition_two_way_flag_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::partition_predicate) { - static constexpr partition_config_params params - = default_partition_flag_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::partition_three_way) { - static constexpr partition_config_params params - = default_partition_predicate_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_flag) { - static constexpr partition_config_params params - = default_partition_three_way_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_predicate) { - static constexpr partition_config_params params - = default_select_flag_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_predicated_flag) { - static constexpr partition_config_params params - = default_select_predicate_config(Arch), KeyType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_unique) { - static constexpr partition_config_params params - = default_select_predicated_flag_config(Arch), - KeyType, - ValueType>{}; - }; -}; - -template -struct wrapped_partition_config -{ - template - struct architecture_config + return type_identity{}; + } + else if constexpr(Algo == partition_subalgo::select_unique_by_key) { - static constexpr partition_config_params params - = default_select_unique_config(Arch), KeyType>{}; - }; -}; + return type_identity{}; + } +} -template -struct wrapped_partition_config +template +struct partition_config_selector { - template - struct architecture_config - { - static constexpr partition_config_params params - = default_select_unique_by_key_config(Arch), - KeyType, - ValueType>{}; - }; -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config< - Arch>::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; + using targets = typename decltype(algo_target_type())::type; + using param_type = partition_config_params; -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; + param_type params; -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config:: - architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -template -template -constexpr partition_config_params - wrapped_partition_config::architecture_config::params; - -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr param_type picker_helper() + { + if constexpr(SubAlgo == partition_subalgo::partition_two_way_predicate) + { + return partition_two_way_predicate_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::partition_two_way_flag) + { + return partition_two_way_flag_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::partition_flag) + { + return partition_flag_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::partition_predicate) + { + return partition_predicate_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::partition_three_way) + { + return partition_three_way_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_flag) + { + return select_flag_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_predicate) + { + return select_predicate_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_predicated_flag) + { + return select_predicated_flag_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_unique) + { + return select_unique_config_picker(); + } + else if constexpr(SubAlgo == partition_subalgo::select_unique_by_key) + { + return select_unique_by_key_config_picker(); + } + } + + template + constexpr partition_config_selector(Target) : params(picker_helper()) + {} +}; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index 98f7a252b07..213170d0723 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -80,7 +80,9 @@ struct Partitioner using input_type = typename std::iterator_traits::value_type; if(three_way_partitioning) { - using config = typename default_partition_config_base::type; + constexpr auto params = partition_config_params_base(); + using config = select_config; return partition_three_way(temporary_storage, storage_size, input, @@ -96,7 +98,9 @@ struct Partitioner } else { - using config = typename default_partition_config_base::type; + constexpr auto params = partition_config_params_base(); + using config = select_config; return partition(temporary_storage, storage_size, input, From 20d93ca80e765c43720be94720ef1c4be8263eac Mon Sep 17 00:00:00 2001 From: Saiyang Zhang Date: Wed, 12 Nov 2025 17:24:46 +0000 Subject: [PATCH 06/26] Resolve "Update configs for new config system part 1" --- .../benchmark_device_batch_memcpy.cpp | 15 +- .../include/rocprim/device/config_types.hpp | 21 +- .../config/device_adjacent_difference.hpp | 1161 +++++++-------- .../device_adjacent_difference_inplace.hpp | 1161 +++++++-------- .../detail/config/device_adjacent_find.hpp | 1294 ++++++++--------- .../detail/config/device_batch_copy.hpp | 593 ++++---- .../detail/config/device_batch_memcpy.hpp | 593 ++++---- .../detail/config/device_find_first_of.hpp | 766 +++++----- .../device/detail/device_batch_memcpy.hpp | 39 +- .../device/detail/device_config_helper.hpp | 71 +- .../device/device_adjacent_difference.hpp | 28 +- .../device_adjacent_difference_config.hpp | 73 +- .../rocprim/device/device_adjacent_find.hpp | 15 +- .../device/device_adjacent_find_config.hpp | 76 +- .../rocprim/device/device_find_first_of.hpp | 21 +- .../device/device_find_first_of_config.hpp | 41 +- .../rocprim/device/device_memcpy_config.hpp | 66 +- .../test/rocprim/test_device_batch_memcpy.cpp | 15 +- 18 files changed, 2927 insertions(+), 3122 deletions(-) diff --git a/projects/rocprim/benchmark/benchmark_device_batch_memcpy.cpp b/projects/rocprim/benchmark/benchmark_device_batch_memcpy.cpp index 2f11c1112de..310c4eed8c9 100644 --- a/projects/rocprim/benchmark/benchmark_device_batch_memcpy.cpp +++ b/projects/rocprim/benchmark/benchmark_device_batch_memcpy.cpp @@ -133,18 +133,23 @@ BatchMemcpyData prepare_data(hipStream_t stre BatchMemcpyData result; - using config - = rocprim::detail::wrapped_batch_memcpy_config; + using Selector = rocprim::detail::batch_memcpy_config_selector; rocprim::detail::target_arch target_arch; - hipError_t success = rocprim::detail::host_target_arch(stream, target_arch); + hipError_t success = host_target_arch(stream, target_arch); + + rocprim::detail::gpu target_gpu; + success = host_target_gpu(stream, target_gpu); + if(success != hipSuccess) { return result; } - const rocprim::detail::batch_memcpy_config_params params - = rocprim::detail::dispatch_target_arch(target_arch); + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params + = rocprim::detail::get_config(rocprim::default_config{}, get_target); const int32_t wlev_min_size = params.wlev_size_threshold; const int32_t blev_min_size = params.blev_size_threshold; diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 879923c57dc..5b9b53e488f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -428,20 +428,6 @@ struct default_config_selector = Config::template architecture_config::params.kernel_config.block_size; }; -template -struct non_blev_batch_memcpy_config_selector -{ - static constexpr unsigned int block_size = Config::template architecture_config::params - .non_blev_batch_memcpy_kernel_config.block_size; -}; - -template -struct blev_batch_memcpy_config_selector -{ - static constexpr unsigned int block_size = Config::template architecture_config::params - .blev_batch_memcpy_kernel_config.block_size; -}; - template struct histogram_config_selector { @@ -664,9 +650,9 @@ void trampoline_kernel(Kernel kernel) } template class LaunchSelector = default_config_selector> + template class LaunchSelector = default_config_static_selector, + class Kernel> auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan { using Targets = typename ConfigSelector::targets; @@ -716,8 +702,7 @@ template(t, kernel); + const auto launch_plan = make_launch_plan(t, kernel); launch_plan.launch(grid_size, block_size, shmem, stream); return hipGetLastError(); } diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp index 875c58f1454..590cadc9df0 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference.hpp @@ -40,635 +40,538 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_adjacent_difference_config - : default_adjacent_difference_config_base::type -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 1> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 1> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<1024, 2> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 1> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<512, 1> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<256, 1> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<1024, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<512, 1> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 3> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 2> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<256, 1> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 3> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 2> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 7> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<64, 13> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<64, 31> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<64, 13> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<64, 29> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<256, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<128, 19> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<32, 31> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<256, 17> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<128, 3> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 1> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<128, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<1024, 3> -{}; +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {1024, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 1} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {1024, 2} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {1024, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 1} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {512, 1} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {1024, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {256, 1} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {1024, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {512, 1} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 1} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 3} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {128, 7} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 2} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 7} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {256, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {256, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {256, 1} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {64, 13} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {64, 31} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {64, 13} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {64, 29} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {256, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {128, 19} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {32, 31} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {256, 17} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {128, 3} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 1} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {128, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {1024, 3} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + return adjacent_difference_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using adjacent_difference_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp index 492a0e71bfa..81327e83498 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_difference_inplace.hpp @@ -40,635 +40,538 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_adjacent_difference_inplace_config - : default_adjacent_difference_config_base::type -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<1024, 13> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<32, 29> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<512, 11> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<1024, 11> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<64, 29> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<128, 23> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 29> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<128, 23> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<512, 31> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 29> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 11> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<64, 31> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 11> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<64, 29> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<32, 23> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 19> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 17> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<128, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 7> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<1024, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<32, 23> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<64, 19> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 19> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<128, 2> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<256, 3> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<512, 13> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<512, 5> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<64, 31> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<128, 23> -{}; - -// Based on value_type = double -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<256, 11> -{}; - -// Based on value_type = float -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<128, 17> -{}; - -// Based on value_type = rocprim::half -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2))>> : adjacent_difference_config<32, 17> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : adjacent_difference_config<32, 29> -{}; - -// Based on value_type = int64_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : adjacent_difference_config<32, 7> -{}; - -// Based on value_type = int -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : adjacent_difference_config<32, 3> -{}; - -// Based on value_type = short -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : adjacent_difference_config<512, 2> -{}; - -// Based on value_type = int8_t -template -struct default_adjacent_difference_inplace_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1))>> : adjacent_difference_config<128, 11> -{}; +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {1024, 13} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {32, 29} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {512, 11} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {1024, 11} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {64, 29} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {128, 23} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 29} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {128, 23} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {512, 31} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 29} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 11} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {64, 31} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 11} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {64, 29} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {512, 7} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {1024, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {32, 23} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {64, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 19} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 17} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {128, 19} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {128, 2} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {256, 3} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {512, 13} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {512, 5} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {64, 31} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {128, 23} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + // Based on value_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {256, 11} + }; + } + // Based on value_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {128, 17} + }; + } + // Based on value_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2))) + { + return adjacent_difference_config_params{ + {32, 17} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return adjacent_difference_config_params{ + {32, 29} + }; + } + // Based on value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return adjacent_difference_config_params{ + {32, 7} + }; + } + // Based on value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return adjacent_difference_config_params{ + {32, 3} + }; + } + // Based on value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return adjacent_difference_config_params{ + {512, 2} + }; + } + // Based on value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(value_type) <= 1))) + { + return adjacent_difference_config_params{ + {128, 11} + }; + } + // Default case if none of the conditions match + return adjacent_difference_config_params_base(); +} + +template +constexpr auto adjacent_difference_inplace_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_difference_config_params> +{ + return adjacent_difference_inplace_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using adjacent_difference_inplace_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp index 8d601e061d3..84a59486271 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_adjacent_find.hpp @@ -40,702 +40,604 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_adjacent_find_config : default_adjacent_find_config_base::type -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<256, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<256, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<256, 8> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<256, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<128, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1030), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 8> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<512, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 8> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 8> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<512, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<128, 32> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1100), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 4> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 2> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<128, 4> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 64> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 4> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<128, 32> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<256, 32> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1200), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<256, 64> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 2> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 32> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<256, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<512, 2> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<256, 8> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 32> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<128, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 4> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<128, 16> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx906), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 4> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 8> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 32> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 4> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx908), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 4> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 8> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<128, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<128, 32> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<64, 8> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<128, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx90a), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<1024, 16> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 32> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<512, 32> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<1024, 16> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<512, 32> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<512, 32> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx942), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<512, 32> -{}; - -// Based on input_type = double -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 8> -{}; - -// Based on input_type = float -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::half -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int64_t -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 8) && (sizeof(input_type) > 4))>> - : adjacent_find_config<128, 32> -{}; - -// Based on input_type = int -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 4) && (sizeof(input_type) > 2))>> - : adjacent_find_config<64, 64> -{}; - -// Based on input_type = short -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 2) && (sizeof(input_type) > 1))>> - : adjacent_find_config<64, 16> -{}; - -// Based on input_type = int8_t -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 1))>> : adjacent_find_config<64, 16> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::unknown), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<1024, 4> -{}; - -// Based on input_type = rocprim::int128_t -template -struct default_adjacent_find_config< - static_cast(target_arch::gfx1201), - input_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(input_type) <= 16) && (sizeof(input_type) > 8))>> - : adjacent_find_config<128, 16> -{}; +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {256, 16} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {256, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {256, 8} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {256, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 8} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {512, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 8} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 8} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {512, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {128, 32} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {1024, 4} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 2} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {128, 4} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {64, 64} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 4} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {128, 32} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {256, 32} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {256, 64} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 2} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 32} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {256, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {512, 2} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {256, 8} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 32} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {128, 4} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {1024, 4} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {128, 8} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {128, 32} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 4} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {1024, 4} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 8} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {128, 32} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {64, 8} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {128, 16} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {64, 16} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + // Based on input_type = double + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {1024, 16} + }; + } + // Based on input_type = float + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Based on input_type = rocprim::half + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Based on input_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 16) + && (sizeof(input_type) > 8))) + { + return adjacent_find_config_params{ + {1024, 16} + }; + } + // Based on input_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 8) + && (sizeof(input_type) > 4))) + { + return adjacent_find_config_params{ + {1024, 16} + }; + } + // Based on input_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 4) + && (sizeof(input_type) > 2))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Based on input_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(input_type) <= 2) + && (sizeof(input_type) > 1))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Based on input_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) + && (sizeof(input_type) <= 1))) + { + return adjacent_find_config_params{ + {512, 32} + }; + } + // Default case if none of the conditions match + return adjacent_find_config_params_base(); +} + +template +constexpr auto adjacent_find_config_picker() -> std::enable_if_t< + std::is_same>::value, + adjacent_find_config_params> +{ + return adjacent_find_config_picker< + comp_target, + input_type>(); +} + +// All the existing configs should be auto generated +using adjacent_find_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_copy.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_copy.hpp index e38b1621b4f..99d1f468357 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_copy.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_copy.hpp @@ -38,255 +38,350 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_batch_copy_config : default_batch_memcpy_config_base::type -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_copy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_copy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_copy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_copy_config(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_copy_config<256, 2, 8, 128, 32, 128, 1024> -{}; +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_copy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + return batch_copy_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using batch_copy_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_memcpy.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_memcpy.hpp index a3d1b21a33e..efeb5e0e689 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_memcpy.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_batch_memcpy.hpp @@ -38,255 +38,350 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_batch_memcpy_config : default_batch_memcpy_config_base::type -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int64_t -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = short -template -struct default_batch_memcpy_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; - -// Based on value_type = int8_t -template -struct default_batch_memcpy_config(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : batch_memcpy_config<256, 2, 8, 128, 32, 128, 1024> -{}; +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return batch_memcpy_config_params{ + {256, 2}, + 8, + {128, 32}, + 128, + 1024 + }; + } + // Default case if none of the conditions match + return batch_memcpy_config_params_base(); +} + +template +constexpr auto batch_memcpy_config_picker() -> std::enable_if_t< + std::is_same>::value, + batch_memcpy_config_params> +{ + return batch_memcpy_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using batch_memcpy_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp index 929d552f288..d60b0c2f86d 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_find_first_of.hpp @@ -40,397 +40,381 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_find_first_of_config : default_find_first_of_config_base::type -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 12> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 12> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<64, 15> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1030), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 9> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<128, 13> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 9> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<64, 13> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1100), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<128, 15> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 12> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1200), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<128, 16> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx1200), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 16> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 6> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<64, 8> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<128, 16> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 15> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<1024, 14> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<64, 16> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 11> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx906), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 8> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 11> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx908), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 6> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<128, 9> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 15> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx90a), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<256, 8> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<256, 11> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<256, 10> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::unknown), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 4> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<1024, 6> -{}; - -// Based on value_type = int64_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : find_first_of_config<1024, 7> -{}; - -// Based on value_type = int -template -struct default_find_first_of_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : find_first_of_config<1024, 6> -{}; - -// Based on value_type = short -template -struct default_find_first_of_config< - static_cast(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : find_first_of_config<1024, 9> -{}; - -// Based on value_type = int8_t -template -struct default_find_first_of_config(target_arch::gfx942), - value_type, - std::enable_if_t<((sizeof(value_type) <= 1))>> - : find_first_of_config<1024, 11> -{}; - -// Based on value_type = rocprim::int128_t -template -struct default_find_first_of_config< - static_cast(target_arch::gfx1201), - value_type, - std::enable_if_t<((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : find_first_of_config<256, 3> -{}; +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {256, 12} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {256, 12} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {64, 15} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 9} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {128, 13} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {256, 9} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {64, 13} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {128, 15} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {256, 12} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {128, 16} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 6} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {64, 8} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {128, 16} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 3} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 15} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {1024, 14} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {64, 16} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {256, 11} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 8} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {256, 11} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {256, 6} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {128, 9} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {256, 15} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {256, 10} + }; + } + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {256, 4} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + // Based on value_type = rocprim::int128_t + if constexpr(((sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return find_first_of_config_params{ + {1024, 6} + }; + } + // Based on value_type = int64_t + if constexpr(((sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return find_first_of_config_params{ + {1024, 7} + }; + } + // Based on value_type = int + if constexpr(((sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return find_first_of_config_params{ + {1024, 6} + }; + } + // Based on value_type = short + if constexpr(((sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return find_first_of_config_params{ + {1024, 9} + }; + } + // Based on value_type = int8_t + if constexpr(((sizeof(value_type) <= 1))) + { + return find_first_of_config_params{ + {1024, 11} + }; + } + // Default case if none of the conditions match + return find_first_of_config_params_base(); +} + +template +constexpr auto find_first_of_config_picker() -> std::enable_if_t< + std::is_same>::value, + find_first_of_config_params> +{ + return find_first_of_config_picker< + comp_target, + value_type>(); +} + +// All the existing configs should be auto generated +using find_first_of_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp index 75c0fdebf93..abc72f8528c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_batch_memcpy.hpp @@ -1104,16 +1104,17 @@ static hipError_t batch_memcpy_func(void* temporary_storage, ROCPRIM_RETURN_ON_ERROR(std::visit( [&](auto use_atomic_block_id) { - using config = wrapped_batch_memcpy_config< - Config, - typename std::iterator_traits::value_type, - IsMemCpy>; + using Selector = batch_memcpy_config_selector; - detail::target_arch target_arch; - ROCPRIM_RETURN_ON_ERROR(detail::host_target_arch(stream, target_arch)); + target_arch target_arch; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const detail::batch_memcpy_config_params params - = detail::dispatch_target_arch(target_arch); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); using BufferOffsetType = unsigned int; using BlockOffsetType = unsigned int; @@ -1241,10 +1242,9 @@ static hipError_t batch_memcpy_func(void* temporary_storage, }; auto blev_memcpy_launch_plan - = make_launch_plan(target_arch, - blev_memcpy_kernel); + = make_launch_plan( + current_target, + blev_memcpy_kernel); int blev_occupancy{}; hipError_t error @@ -1320,14 +1320,13 @@ static hipError_t batch_memcpy_func(void* temporary_storage, start_timer(); ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(target_arch, - non_blev_memcpy_kernel, - batch_memcpy_grid_size, - non_blev_block_size, - 0, - stream)); + execute_launch_plan( + current_target, + non_blev_memcpy_kernel, + batch_memcpy_grid_size, + non_blev_block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("non_blev_memcpy_kernel", num_copies, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 5a63461667f..d45434efdba 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -878,9 +878,9 @@ struct adjacent_difference_config_tag struct adjacent_difference_config_params { - kernel_config_params kernel_config; - ::rocprim::block_load_method block_load_method; - ::rocprim::block_store_method block_store_method; + kernel_config_params kernel_config{}; + ::rocprim::block_load_method block_load_method{}; + ::rocprim::block_store_method block_store_method{}; }; } // namespace detail @@ -919,16 +919,17 @@ namespace detail { template -struct default_adjacent_difference_config_base +constexpr adjacent_difference_config_params adjacent_difference_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = adjacent_difference_config< - limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, - ::rocprim::max(1u, 16u / item_scale), + return adjacent_difference_config_params{ + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, ::rocprim::block_load_method::block_load_transpose, - ::rocprim::block_store_method::block_store_transpose>; + ::rocprim::block_store_method::block_store_transpose + }; }; } // namespace detail @@ -942,19 +943,19 @@ struct batch_memcpy_config_tag struct batch_memcpy_config_params { /// \brief Kernel config for thread- and warp-level copy - kernel_config_params non_blev_batch_memcpy_kernel_config; + kernel_config_params non_blev_batch_memcpy_kernel_config{}; /// \brief Number of bytes per thread for thread-level copy - unsigned int tlev_items_per_thread; + unsigned int tlev_items_per_thread = 0; /// \brief Kernel config for block-level copy - kernel_config_params blev_batch_memcpy_kernel_config; + kernel_config_params blev_batch_memcpy_kernel_config{}; /// \brief Minimum size to use warp-level copy instead of thread-level - unsigned int wlev_size_threshold; + unsigned int wlev_size_threshold = 0; /// \brief Minimum size to use block-level copy instead of warp-level - unsigned int blev_size_threshold; + unsigned int blev_size_threshold = 0; }; } // namespace detail @@ -1001,19 +1002,20 @@ namespace detail { template -struct default_batch_memcpy_config_base +constexpr batch_memcpy_config_params batch_memcpy_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type - = batch_memcpy_config::value, - ::rocprim::max(1u, 16u / item_scale), - ::rocprim::max(1u, 16u / item_scale), - limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, - ::rocprim::max(1u, 16u / item_scale), - 128, - 1024>; + return batch_memcpy_config_params{ + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, + ::rocprim::max(1u, 16u / item_scale), + {limit_block_size<256U, sizeof(Value), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)}, + 128, + 1024 + }; }; } // namespace detail @@ -1372,7 +1374,7 @@ struct adjacent_find_config_tag struct adjacent_find_config_params { - kernel_config_params kernel_config; + kernel_config_params kernel_config{}; }; } // namespace detail @@ -1415,23 +1417,26 @@ namespace detail { template -struct default_find_first_of_config_base +constexpr find_first_of_config_params find_first_of_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Value), sizeof(int)); - using type = find_first_of_config<256, ::rocprim::max(1u, 16u / item_scale)>; + return find_first_of_config_params{ + {256, ::rocprim::max(1u, 16u / item_scale)} + }; }; template -struct default_adjacent_find_config_base +constexpr adjacent_find_config_params adjacent_find_config_params_base() { - static constexpr unsigned int item_scale + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(InputT), sizeof(int)); - using type - = adjacent_find_config::value, - ::rocprim::max(1u, 16u / item_scale)>; + return adjacent_find_config_params{ + {limit_block_size<1024U, sizeof(InputT), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 16u / item_scale)} + }; }; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference.hpp index b8ebb7b961a..91d18c5db80 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference.hpp @@ -77,17 +77,17 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, using larger_type = std::conditional_t<(sizeof(value_type) >= sizeof(output_type)), value_type, output_type>; - using config = wrapped_adjacent_difference_config; + using Selector = adjacent_difference_config_selector; detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } + ROCPRIM_RETURN_ON_ERROR(detail::host_target_arch(stream, target_arch)); + + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); - const detail::adjacent_difference_config_params params - = detail::dispatch_target_arch(target_arch); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -183,12 +183,12 @@ hipError_t adjacent_difference_impl(void* const temporary_storage, starting_block); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - adjacent_difference_kernel, - current_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + adjacent_difference_kernel, + current_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("adjacent_difference_kernel", current_size, diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp index e68391dd03f..9e06b2631a5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_difference_config.hpp @@ -43,63 +43,34 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// Specialization for user provided configuration -template -struct wrapped_adjacent_difference_config +template +struct adjacent_difference_config_selector { - static_assert( - std::is_same::value, - "Config must be a specialization of struct template adjacent_difference_config"); + using targets = std:: + conditional_t; + using param_type = adjacent_difference_config_params; - template - struct architecture_config - { - static constexpr adjacent_difference_config_params params = AdjacentDifferenceConfig{}; - }; -}; - -// Specialization for selecting the default configuration for in place -template -struct wrapped_adjacent_difference_config -{ - template - struct architecture_config - { - static constexpr adjacent_difference_config_params params - = default_adjacent_difference_inplace_config(Arch), Value>{}; - }; -}; + param_type params; -// Specialization for selecting the default configuration for out of place -template -struct wrapped_adjacent_difference_config -{ - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr adjacent_difference_config_params params - = default_adjacent_difference_config(Arch), Value>{}; - }; + // Different configs if it is inplace. + if constexpr(InPlace) + { + return adjacent_difference_inplace_config_picker(); + } + else + { + return adjacent_difference_config_picker(); + } + } + + template + constexpr adjacent_difference_config_selector(Target) : params(picker_helper()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr adjacent_difference_config_params - wrapped_adjacent_difference_config:: - architecture_config::params; -template -template -constexpr adjacent_difference_config_params - wrapped_adjacent_difference_config::architecture_config< - Arch>::params; -template -template -constexpr adjacent_difference_config_params - wrapped_adjacent_difference_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find.hpp index 65f74291359..539696ec499 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find.hpp @@ -69,8 +69,7 @@ hipError_t adjacent_find_impl(void* const temporary_storage, // Use dynamic tile id using ordered_tile_id_type = detail::ordered_block_id; - // Kernel launch config - using config = wrapped_adjacent_find_config; + using Selector = adjacent_find_config_selector; // Transform input auto wrapped_equal_op = [op, size](const wrapped_input_type& a) -> index_type @@ -138,7 +137,14 @@ hipError_t adjacent_find_impl(void* const temporary_storage, target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); - const adjacent_find_config_params params = dispatch_target_arch(target_arch); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); + const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -155,7 +161,8 @@ hipError_t adjacent_find_impl(void* const temporary_storage, ordered_tile_id); }; - auto adjacent_find_block_reduce_kernel = make_launch_plan(target_arch, kernel); + auto adjacent_find_block_reduce_kernel + = make_launch_plan(current_target, kernel); // Get grid size for maximum occupancy, as we may not be able to schedule all the blocks // at the same time diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find_config.hpp index 5a2990849c1..6e31a488448 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_adjacent_find_config.hpp @@ -34,67 +34,33 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_adjacent_find_config +template +struct adjacent_find_config_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template adjacent_find_config"); + using targets = adjacent_find_targets; + using param_type = adjacent_find_config_params; - template - struct architecture_config - { - static constexpr adjacent_find_config_params params = Config{}; - }; -}; + param_type params; -// Generic for default config: instantiate base config. -template -struct wrapped_adjacent_find_impl -{ - template - struct architecture_config + template + constexpr param_type picker_helper() { - static constexpr adjacent_find_config_params params = - typename default_adjacent_find_config_base::type{}; - }; + // Different configs if it is inplace. + if constexpr(rocprim::is_arithmetic::value) + { + return adjacent_find_config_picker(); + } + else + { + return adjacent_find_config_params_base(); + } + } + + template + constexpr adjacent_find_config_selector(Target) : params(picker_helper()) + {} }; -// Specialization for default config if types are arithmetic or half/bfloat16-precision -// floating point types: instantiate the tuned config. -template -struct wrapped_adjacent_find_impl::value>> -{ - template - struct architecture_config - { - static constexpr adjacent_find_config_params params - = default_adjacent_find_config(Arch), Type>(); - }; -}; - -// Specialization for default config. -template -struct wrapped_adjacent_find_config : wrapped_adjacent_find_impl -{}; - -#ifndef DOXYGEN_DOCUMENTATION_BUILD -template -template -constexpr adjacent_find_config_params - wrapped_adjacent_find_config::architecture_config::params; - -template -template -constexpr adjacent_find_config_params - wrapped_adjacent_find_impl::architecture_config::params; - -template -template -constexpr adjacent_find_config_params wrapped_adjacent_find_impl< - Type, - std::enable_if_t::value>>::architecture_config::params; -#endif // DOXYGEN_DOCUMENTATION_BUILD - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of.hpp index 6db6fa91388..fd64363a784 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of.hpp @@ -183,17 +183,19 @@ hipError_t find_first_of_impl(void* temporary_storage, bool debug_synchronous) { using type = typename std::iterator_traits::value_type; - using config = wrapped_find_first_of_config; + using Selector = find_first_of_config_selector; using find_first_of_kernels = find_first_of_impl_kernels; target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const find_first_of_config_params params = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -207,7 +209,7 @@ hipError_t find_first_of_impl(void* temporary_storage, ordered_bid_type::id_type* ordered_bid_storage = nullptr; // Calculate required temporary storage - result = temp_storage::partition( + hipError_t result = temp_storage::partition( temporary_storage, storage_size, temp_storage::make_linear_partition( @@ -246,7 +248,8 @@ hipError_t find_first_of_impl(void* temporary_storage, compare_function); }; - auto find_first_of_configured_kernel = make_launch_plan(target_arch, kernel); + auto find_first_of_configured_kernel + = make_launch_plan(current_target, kernel); const size_t shared_memory_size = 0; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of_config.hpp index 26c0a30d7a4..610389f2118 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_find_first_of_config.hpp @@ -33,40 +33,19 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_find_first_of_config +template +struct find_first_of_config_selector { - template - struct architecture_config - { - static constexpr find_first_of_config_params params = Config{}; - }; -}; + using targets = find_first_of_targets; + using param_type = find_first_of_config_params; -// specialized for rocprim::default_config, which instantiates the default_find_first_of_config -template -struct wrapped_find_first_of_config -{ - template - struct architecture_config - { - static constexpr find_first_of_config_params params - = default_find_first_of_config(Arch), Type>(); - }; -}; + param_type params; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr find_first_of_config_params - wrapped_find_first_of_config::architecture_config::params; - -template -template -constexpr find_first_of_config_params - wrapped_find_first_of_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr find_first_of_config_selector(Target) + : params(find_first_of_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp index 5e8af2a3a62..75a0a1dec5d 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp @@ -40,47 +40,45 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// Specialization for user provided configuration -template -struct wrapped_batch_memcpy_config +template +struct non_blev_batch_memcpy_config_static_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template batch_memcpy_config"); - - template - struct architecture_config - { - static constexpr batch_memcpy_config_params params = BatchMemcpyConfig{}; - }; + static constexpr auto block_size = target_config2::params + .non_blev_batch_memcpy_kernel_config.block_size; }; -// Specialization for selecting the default configuration for out of place -template -struct wrapped_batch_memcpy_config +template +struct blev_batch_memcpy_config_static_selector { - template - struct architecture_config - { - static constexpr batch_memcpy_config_params params - = IsMemCpy ? (batch_memcpy_config_params) - default_batch_memcpy_config(Arch), Value>{} - : (batch_memcpy_config_params) - default_batch_copy_config(Arch), Value>{}; - }; + static constexpr auto block_size = target_config2::params + .blev_batch_memcpy_kernel_config.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr batch_memcpy_config_params - wrapped_batch_memcpy_config::architecture_config< - Arch>::params; template -template -constexpr batch_memcpy_config_params - wrapped_batch_memcpy_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS +struct batch_memcpy_config_selector +{ + using targets = std::conditional_t; + using param_type = batch_memcpy_config_params; + + param_type params; + + template + constexpr param_type picker_helper() + { + if constexpr(IsMemCpy) + { + return batch_memcpy_config_picker(); + } + else + { + return batch_copy_config_picker(); + } + } + + template + constexpr batch_memcpy_config_selector(Target) : params(picker_helper()) + {} +}; } // namespace detail diff --git a/projects/rocprim/test/rocprim/test_device_batch_memcpy.cpp b/projects/rocprim/test/rocprim/test_device_batch_memcpy.cpp index 48060147c94..d5009f89d97 100644 --- a/projects/rocprim/test/rocprim/test_device_batch_memcpy.cpp +++ b/projects/rocprim/test/rocprim/test_device_batch_memcpy.cpp @@ -256,15 +256,20 @@ TYPED_TEST(RocprimDeviceBatchMemcpyTests, SizeAndTypeVariation) constexpr bool use_indirect_iterator = TestFixture::use_indirect_iterator; constexpr bool debug_synchronous = TestFixture::debug_synchronous; - using config = rocprim::detail:: - wrapped_batch_memcpy_config; + using Selector = rocprim::detail::batch_memcpy_config_selector; rocprim::detail::target_arch target_arch; - hipError_t success = rocprim::detail::host_target_arch(hipStreamDefault, target_arch); + hipError_t success = host_target_arch(hipStreamDefault, target_arch); + + rocprim::detail::gpu target_gpu; + success = host_target_gpu(hipStreamDefault, target_gpu); + ASSERT_EQ(success, hipSuccess); - const rocprim::detail::batch_memcpy_config_params params - = rocprim::detail::dispatch_target_arch(target_arch); + const rocprim::detail::target get_target(target_arch, target_gpu); + + const auto params + = rocprim::detail::get_config(rocprim::default_config{}, get_target); const int32_t wlev_min_size = params.wlev_size_threshold; const int32_t blev_min_size = params.blev_size_threshold; From 78de57f3dbf4828a0463a15522227cb501eec66f Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 13 Nov 2025 10:39:17 +0000 Subject: [PATCH 07/26] Resolve "Update configs for new config system part 2" --- .../include/rocprim/device/config_types.hpp | 40 +- .../device/detail/config/device_histogram.hpp | 6341 ++++++------- .../device/detail/config/device_merge.hpp | 8101 +++++++--------- .../config/device_merge_sort_block_merge.hpp | 8254 ++++++++--------- .../config/device_merge_sort_block_sort.hpp | 7487 ++++++--------- .../device/detail/device_config_helper.hpp | 126 +- .../rocprim/device/detail/device_search_n.hpp | 12 +- .../rocprim/device/device_histogram.hpp | 127 +- .../device/device_histogram_config.hpp | 56 +- .../include/rocprim/device/device_merge.hpp | 79 +- .../rocprim/device/device_merge_config.hpp | 40 +- .../rocprim/device/device_merge_sort.hpp | 235 +- .../device/device_merge_sort_config.hpp | 96 +- .../rocprim/device/device_partition.hpp | 8 +- .../rocprim/device/device_radix_sort.hpp | 18 +- .../include/rocprim/device/device_reduce.hpp | 8 +- .../rocprim/device/device_reduce_by_key.hpp | 6 +- .../device/device_run_length_encode.hpp | 6 +- .../include/rocprim/device/device_scan.hpp | 8 +- .../rocprim/device/device_scan_by_key.hpp | 6 +- .../device/device_segmented_radix_sort.hpp | 12 +- .../device/device_segmented_reduce.hpp | 6 +- .../rocprim/device/device_segmented_scan.hpp | 6 +- .../rocprim/device/device_transform.hpp | 6 +- .../device_radix_block_sort.hpp | 6 +- .../device_radix_merge_sort.hpp | 13 +- .../rocprim/test/rocprim/test_device_scan.cpp | 12 +- .../test/rocprim/test_device_search_n.cpp | 20 +- .../test/rocprim/test_linking_new_scan.hpp | 6 +- 29 files changed, 12934 insertions(+), 18207 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 5b9b53e488f..e812fdae740 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -256,6 +256,8 @@ constexpr std::tuple target_gpu_names[] = { std::make_tuple("RX 6900", gpu::rx6900), }; +// TODO: Remove comp_target when adopting C++20 and dropping C++17 support. +// comp_target exists, because target can not be passed as a template variable before C++20. template struct comp_target { @@ -265,6 +267,9 @@ struct comp_target static constexpr rep r = r_; }; +// Macro to have a singular place for conversion, limited by C++17. +#define TARGET_TO_COMP_TARGET(CT) comp_target<(CT).g, (CT).i, (CT).s, (CT).r> + struct target { gen g; @@ -428,41 +433,6 @@ struct default_config_selector = Config::template architecture_config::params.kernel_config.block_size; }; -template -struct histogram_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.histogram_config.block_size; -}; - -template -struct histogram_global_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.histogram_global_config.block_size; -}; - -template -struct merge_oddeven_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.merge_oddeven_config.block_size; -}; - -template -struct merge_mergepath_partition_config_selector -{ - static constexpr unsigned int block_size = Config::template architecture_config::params - .merge_mergepath_partition_config.block_size; -}; - -template -struct merge_mergepath_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.merge_mergepath_config.block_size; -}; - template struct target_config { diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp index 7c73fde9b6e..e5f67d94e9a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp @@ -40,3585 +40,2768 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_histogram_config - : default_histogram_config_base::type -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1030), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx1100), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx906), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx908), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx90a), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 1024, 2048, 3> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 2> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::unknown), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4> -{}; - -// Based on value_type = double, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = double, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = double, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = double, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = double, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = float, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::half, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 2, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 2, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 3, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int64_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 1) - && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 2) - && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 3) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = short, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1) && (channels == 4) - && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 1, active_channels = 1 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 1) && (active_channels == 1))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 2, active_channels = 2 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 2) && (active_channels == 2))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 3, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 3) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 3 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 3))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; - -// Based on value_type = int8_t, channels = 4, active_channels = 4 -template -struct default_histogram_config< - static_cast(target_arch::gfx942), - value_type, - channels, - active_channels, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) - && (sizeof(value_type) <= 1) && (channels == 4) && (active_channels == 4))>> - : histogram_config, 2048, 2048, 4, kernel_config<1024, 4>> -{}; +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 5}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 9}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 9}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 3}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 6}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 12}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 6}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 3}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 6}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 14}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 7}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{64, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 14}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 7}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 16}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{64, 10}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 6}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 13}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 7}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 7}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 16}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 16}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 10}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 10}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 1024, + 2048, + 3 + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{64, 8}, + 2048, + 2048, + 2 + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 3 + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 3}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 10}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{128, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{128, 1}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{256, 12}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 2}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{128, 15}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{256, 8}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 5}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{256, 4}, + 2048, + 2048, + 4 + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Based on value_type = double, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 16}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = double, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 128, 2}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = double, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 5}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = double, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 128, 1}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = double, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 1}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 15}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 128, 3}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 5}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = float, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 3}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 1, active_channels = 1 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 15}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 2, active_channels = 2 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 8}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 3, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 3 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::half, channels = 4, active_channels = 4 + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 2, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 2, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = rocprim::int128_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 1}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 6}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 3}, + 2048, + 2048, + 3, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 3}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 1}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int64_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 3}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 9}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 128, 6}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 2}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 16}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 8}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 5}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = short, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1) && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 1, active_channels = 1 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 1) && (active_channels == 1))) + { + return histogram_config_params{ + kernel_config_params{ 256, 12}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 2, active_channels = 2 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 2) && (active_channels == 2))) + { + return histogram_config_params{ + kernel_config_params{ 256, 8}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 3, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 3) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 5}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 3 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 3))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Based on value_type = int8_t, channels = 4, active_channels = 4 + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(value_type) <= 1) + && (channels == 4) && (active_channels == 4))) + { + return histogram_config_params{ + kernel_config_params{ 256, 4}, + 2048, + 2048, + 4, + kernel_config_params{1024, 4} + }; + } + // Default case if none of the conditions match + return histogram_config_params_base(); +} + +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + return histogram_config_picker< + comp_target, + value_type, + channels, + active_channels>(); +} + +// All the existing configs should be auto generated +using histogram_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge.hpp index a970764bb5c..2d73e9ba88f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge.hpp @@ -40,4743 +40,3370 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_merge_config : default_merge_config_base::type -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 16> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 16> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 10> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<256, 5> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 16> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 10> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<1024, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<256, 8> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<32, 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 16> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 16> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 16> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 2> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 7> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 2> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<256, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<512, 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_config<256, 11> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<256, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<256, 2> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 5> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<256, 10> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_config<256, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<1024, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_config<1024, 2> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_config<32, 1> -{}; +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 8} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {32, 1} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {1024, 8} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 4} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 16} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {32, 1} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 7} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 7} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 7} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 7} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {512, 2} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 1} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 11} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {256, 2} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 5} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {256, 10} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_config_params{ + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_config_params{ + {1024, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_config_params{ + {512, 4} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_config_params{ + {512, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_config_params{ + {256, 16} + }; + } + // Default case if none of the conditions match + return merge_config_params_base(); +} + +template +constexpr auto merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_config_params> +{ + return merge_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using merge_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp index d07813ba5bd..ef6c43ecf98 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_merge.hpp @@ -40,4542 +40,3724 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_merge_sort_block_merge_config - : merge_sort_block_merge_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 1> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 1> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 1024, 1> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 512, 2> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 1> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 1> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_merge_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, 128, 128, 8> -{}; +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 1} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 1} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 4} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {1024, 1} + }; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {512, 2} + }; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 1} + }; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 1} + }; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {256, 4} + }; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_merge_config_params{ + {256, 1, (1 << 17) + 70000}, + {128, 1}, + {128, 8} + }; + } + // Default case if none of the conditions match + return merge_sort_block_merge_config_params_base(); +} + +template +constexpr auto merge_sort_block_merge_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_merge_config_params> +{ + return merge_sort_block_merge_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using merge_sort_block_merge_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp index aabb60332e8..135176444d5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_merge_sort_block_sort.hpp @@ -40,4824 +40,2675 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { - -template -struct default_merge_sort_block_sort_config - : merge_sort_block_sort_config_base::type -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1100), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int64_t, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = short, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int8_t, value_type = custom_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1200), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<512, 16> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx906), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx908), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx90a), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::unknown), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = double, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = double, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = double, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = float, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = float, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<1024, 4> -{}; - -// Based on key_type = float, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::half, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::half, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = rocprim::half, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int64_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int64_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = int, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = short, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = short, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = short, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int8_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))>> - : merge_sort_block_sort_config<256, 16> -{}; - -// Based on key_type = int8_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<1024, 8> -{}; - -// Based on key_type = int8_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx942), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 32> -{}; - -// Based on key_type = double, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = float, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = rocprim::half, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int64_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) - && (sizeof(value_type) > 4))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) - && (sizeof(value_type) > 2))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = short -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) - && (sizeof(value_type) > 1))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = int8_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) - && (!std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = rocprim::int128_t, value_type = empty_type -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) - && (sizeof(key_type) > 8) - && (std::is_same::value))>> - : merge_sort_block_sort_config<256, 4> -{}; - -// Based on key_type = int64_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) - && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<256, 8> -{}; - -// Based on key_type = short, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) - && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> : merge_sort_block_sort_config<512, 4> -{}; - -// Based on key_type = int8_t, value_type = rocprim::int128_t -template -struct default_merge_sort_block_sort_config< - static_cast(target_arch::gfx1201), - key_type, - value_type, - std::enable_if_t<(!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) - && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))>> - : merge_sort_block_sort_config<512, 4> -{}; +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = custom_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = custom_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::half, value_type = custom_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int64_t, value_type = custom_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = custom_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = short, value_type = custom_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int8_t, value_type = custom_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 16}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = double, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = double, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = double, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = float, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = float, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = float, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{1024, 4}; + } + // Based on key_type = float, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = float, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = int64_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::half, value_type = int + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = rocprim::half, value_type = short + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = rocprim::half, value_type = int8_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = rocprim::half, value_type = empty_type + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = rocprim::int128_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = rocprim::int128_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 16) + && (sizeof(key_type) > 8) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int64_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int64_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int64_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = int, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 4) + && (sizeof(key_type) > 2) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 8) + && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = short, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 4) + && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{256, 8}; + } + // Based on key_type = short, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 2) + && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = short, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = short, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 2) + && (sizeof(key_type) > 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Based on key_type = int8_t, value_type = rocprim::int128_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = int64_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) + { + return merge_sort_block_sort_config_params{256, 4}; + } + // Based on key_type = int8_t, value_type = int + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) + { + return merge_sort_block_sort_config_params{512, 4}; + } + // Based on key_type = int8_t, value_type = short + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) + { + return merge_sort_block_sort_config_params{256, 16}; + } + // Based on key_type = int8_t, value_type = int8_t + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (sizeof(value_type) <= 1) + && (!std::is_same::value))) + { + return merge_sort_block_sort_config_params{1024, 8}; + } + // Based on key_type = int8_t, value_type = empty_type + if constexpr((!bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 1) + && (std::is_same::value))) + { + return merge_sort_block_sort_config_params{256, 32}; + } + // Default case if none of the conditions match + return merge_sort_block_sort_config_params_base(); +} + +template +constexpr auto merge_sort_block_sort_config_picker() -> std::enable_if_t< + std::is_same>::value, + merge_sort_block_sort_config_params> +{ + return merge_sort_block_sort_config_picker< + comp_target, + key_type, + value_type>(); +} + +// All the existing configs should be auto generated +using merge_sort_block_sort_targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index d45434efdba..36c8cd23414 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -96,20 +96,20 @@ constexpr unsigned int merge_sort_block_size(const unsigned int item_scale) // Calculate kernel configurations, such that it will not exceed shared memory maximum template -struct merge_sort_block_sort_config_base +constexpr merge_sort_block_sort_config_params merge_sort_block_sort_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); - static constexpr bool use_fallback = merge_sort_block_size(item_scale) * 2 - * merge_sort_items_per_thread(item_scale) * item_scale - <= max_smem_per_block; + constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr bool use_fallback = merge_sort_block_size(item_scale) * 2 + * merge_sort_items_per_thread(item_scale) * item_scale + <= max_smem_per_block; // multiply by 2 to ensure block_sort's items_per_block >= block_merge's items_per_block - static constexpr unsigned int block_size - = use_fallback ? merge_sort_block_size(item_scale) * 2 : 256; - static constexpr unsigned int items_per_thread + constexpr unsigned int block_size = use_fallback ? merge_sort_block_size(item_scale) * 2 : 256; + constexpr unsigned int items_per_thread = use_fallback ? merge_sort_items_per_thread(item_scale) : 1; - using type = merge_sort_block_sort_config; + + return merge_sort_block_sort_config_params{ + {block_size, items_per_thread} + }; }; // Calculate kernel configurations, such that it will not exceed shared memory maximum @@ -160,23 +160,21 @@ struct merge_sort_block_merge_config : rocprim::detail::merge_sort_block_merge_c }; template -struct merge_sort_block_merge_config_base -{ - static constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); - static constexpr bool use_fallback = merge_sort_block_size(item_scale) * 2 - * merge_sort_items_per_thread(item_scale) * item_scale - <= max_smem_per_block; - static constexpr unsigned int block_size - = use_fallback ? merge_sort_block_size(item_scale) : 128; - static constexpr unsigned int items_per_thread +constexpr merge_sort_block_merge_config_params merge_sort_block_merge_config_params_base() +{ + constexpr unsigned int item_scale = ::rocprim::max(sizeof(Key), sizeof(Value)); + constexpr bool use_fallback = merge_sort_block_size(item_scale) * 2 + * merge_sort_items_per_thread(item_scale) * item_scale + <= max_smem_per_block; + constexpr unsigned int block_size = use_fallback ? merge_sort_block_size(item_scale) : 128; + constexpr unsigned int items_per_thread = use_fallback ? merge_sort_items_per_thread(item_scale) : 1; - using type = merge_sort_block_merge_config; -}; + return merge_sort_block_merge_config_params{ + {block_size, 1, (1 << 17) + 70000}, + {128, 1}, + {block_size, items_per_thread} + }; +} /// \brief Default values are provided by \p radix_sort_onesweep_config_base. struct radix_sort_onesweep_config_params @@ -819,7 +817,7 @@ struct histogram_config_params unsigned int shared_impl_max_bins = 0; unsigned int shared_impl_histograms = 0; - kernel_config_params histogram_global_config = {0, 0}; + kernel_config_params histogram_global_config = histogram_config; }; } // namespace detail @@ -864,13 +862,14 @@ namespace detail { template -struct default_histogram_config_base +constexpr histogram_config_params histogram_config_params_base() { - static constexpr unsigned int item_scale - = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div(sizeof(Sample), sizeof(int)); + + constexpr kernel_config_params kernel_params + = {256, ::rocprim::max(8u / Channels / item_scale, 1u)}; - using type - = histogram_config>; + return histogram_config_params{kernel_params, 1024, 2048, 3, kernel_params}; }; struct adjacent_difference_config_tag @@ -1549,30 +1548,45 @@ namespace detail { template -struct default_merge_config_base +constexpr merge_config_params merge_config_params_base() { - static constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( - ::rocprim::max(sizeof(Key), sizeof(Value)), sizeof(int)); - - using type = merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>; -}; - -template -struct default_merge_config_base -{ - static constexpr unsigned int item_scale - = ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); - - using type = select_type< - select_type_case>, - select_type_case>, - select_type_case>, - merge_config::value, - ::rocprim::max(1u, 10u / item_scale)>>; -}; + if constexpr(std::is_same_v) + { + constexpr unsigned int item_scale + = ::rocprim::detail::ceiling_div(sizeof(Key), sizeof(int)); + + if constexpr(sizeof(Key) <= 2) + return merge_config_params{ + {256, 11} + }; + else if constexpr(sizeof(Key) <= 4) + return merge_config_params{ + {256, 10} + }; + else if constexpr(sizeof(Key) <= 8) + return merge_config_params{ + {256, 7} + }; + else + return merge_config_params{ + {fallback_block_size<256u, sizeof(Key), ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale)} + }; + } + else + { + constexpr unsigned int item_scale = ::rocprim::detail::ceiling_div( + ::rocprim::max(sizeof(Key), sizeof(Value)), + sizeof(int)); + + return merge_config_params{ + {fallback_block_size<256u, + ::rocprim::max(sizeof(Key), sizeof(Value)), + ROCPRIM_WARP_SIZE_64>::value, + ::rocprim::max(1u, 10u / item_scale)} + }; + } +} } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp index 549d19eaeca..4e7c168ec01 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search_n.hpp @@ -79,9 +79,9 @@ hipError_t search_n_impl(void* temporary_storage, gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -169,7 +169,7 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, search_n_normal_kernel, num_blocks, block_size, @@ -249,7 +249,7 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, search_n_find_heads_kernel, num_blocks_for_find_heads, block_size, @@ -299,7 +299,7 @@ hipError_t search_n_impl(void* temporary_storage, filtered_heads[atomic_add(tmp_output, 1)] = this_head; } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, search_n_heads_filter_kernel, num_blocks_for_heads_filter, block_size, @@ -392,7 +392,7 @@ hipError_t search_n_impl(void* temporary_storage, } } }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, search_n_discard_heads_kernel, num_blocks_for_discard_heads, block_size, diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp index f102b51b0e2..b4ac6e6b4ed 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp @@ -124,63 +124,6 @@ struct HistogramSharedOp } }; -template -struct histogram_launch_plan -{ - using kernel_type = void (*)(Kernel); - - kernel_type kernel; - Kernel device_callback; - - unsigned int shared_impl_histograms = 0; - unsigned int max_grid_size = 0; - - void launch(dim3 grid, dim3 block, size_t shmem, hipStream_t stream) const - { - hipLaunchKernelGGL(HIP_KERNEL_NAME(kernel), grid, block, shmem, stream, device_callback); - } -}; - -template - class LaunchSelector> -auto make_histogram_launch_plan(rocprim::detail::target_arch arch, Kernel kernel) - -> histogram_launch_plan -{ - histogram_launch_plan plan{nullptr, std::move(kernel), 0u, 0u}; - - bool found = false; - - for_each_arch( - [&](auto arch_tag) - { - constexpr auto Arch = decltype(arch_tag)::value; - if(Arch != arch || found) - return; - - plan.kernel = trampoline_kernel; - - constexpr auto params = Config::template architecture_config::params; - - plan.shared_impl_histograms = params.shared_impl_histograms; - plan.max_grid_size = params.max_grid_size; - - found = true; - }); - if(!found) - { - constexpr auto Arch = rocprim::detail::target_arch::unknown; - - plan.kernel = trampoline_kernel; - - constexpr auto params = Config::template architecture_config::params; - plan.shared_impl_histograms = params.shared_impl_histograms; - plan.max_grid_size = params.max_grid_size; - } - return plan; -} - template::value_type; - - using config = wrapped_histogram_config; + using selector = histogram_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const histogram_config_params params = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.histogram_config.block_size; const unsigned int items_per_thread = params.histogram_config.items_per_thread; const unsigned int shared_impl_max_bins = params.shared_impl_max_bins; @@ -333,14 +275,13 @@ inline hipError_t histogram_impl(void* temporary_storage, init_histogram(hist, bin_counts); }; - ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, - init_histogram_kernel, - ::rocprim::detail::ceiling_div(max_bins, block_size), - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( + current_target, + init_histogram_kernel, + ::rocprim::detail::ceiling_div(max_bins, block_size), + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("init_histogram_kernel", max_bins, start); @@ -367,9 +308,9 @@ inline hipError_t histogram_impl(void* temporary_storage, 0, 0}; - auto plan = make_histogram_launch_plan( - target_arch, - op); + auto plan + = make_launch_plan(current_target, + op); const size_t block_histogram_bytes = total_shared_bins * sizeof(unsigned int); @@ -379,7 +320,7 @@ inline hipError_t histogram_impl(void* temporary_storage, // memory usage unsigned int chosen_shared_histograms = 0; int max_blocks_per_mp = 0; - for(unsigned int n = plan.shared_impl_histograms; n >= 1; n--) + for(unsigned int n = params.shared_impl_histograms; n >= 1; n--) { int blocks_per_mp; ROCPRIM_RETURN_ON_ERROR(hipOccupancyMaxActiveBlocksPerMultiprocessor( @@ -466,15 +407,14 @@ inline hipError_t histogram_impl(void* temporary_storage, block_id_count); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( - target_arch, - histogram_private_global_kernel, - dim3(global_histogram_grid_size), - dim3(params.histogram_global_config.block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan( + current_target, + histogram_private_global_kernel, + dim3(global_histogram_grid_size), + dim3(params.histogram_global_config.block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("histogram_private_global_kernel", blocks_x * block_size * rows, @@ -502,14 +442,13 @@ inline hipError_t histogram_impl(void* temporary_storage, }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(target_arch, - histogram_global_kernel, - dim3(blocks_x, rows), - dim3(block_size, 1), - 0, - stream)); + execute_launch_plan( + current_target, + histogram_global_kernel, + dim3(blocks_x, rows), + dim3(block_size, 1), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("histogram_global_kernel", blocks_x * block_size * rows, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp index b631e2a4b5b..c2d5785649f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp @@ -32,49 +32,33 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -struct wrapped_histogram_config +template +struct histogram_config_static_selector { - static_assert(std::is_same::value, - "Config must be a specialization of struct template histogram_config"); - - template - struct architecture_config - { - static constexpr histogram_config_params params = HistogramConfig{}; - }; + static constexpr auto block_size + = target_config2::params.histogram_config.block_size; }; -template -struct wrapped_histogram_config +template +struct histogram_global_config_static_selector { - template - struct architecture_config - { - static constexpr histogram_config_params params - = default_histogram_config(Arch), - Sample, - Channels, - ActiveChannels>{}; - }; + static constexpr auto block_size + = target_config2::params.histogram_global_config.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr histogram_config_params - wrapped_histogram_config:: - architecture_config::params; +template +struct histogram_config_selector +{ + using targets = histogram_targets; + using param_type = histogram_config_params; + + param_type params; -template -template -constexpr histogram_config_params - wrapped_histogram_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS + template + constexpr histogram_config_selector(Target) + : params(histogram_config_picker()) + {} +}; } // end namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp index fba3fd5c389..7765ee05d93 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp @@ -42,32 +42,27 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -template -inline size_t get_merge_vsmem_size_per_block(detail::target_arch arch) +template +inline size_t get_merge_vsmem_size_per_block(detail::target t) { - std::optional vsmem_per_block; - for_each_arch( - [&](auto arch_tag) + using targets = typename Selector::targets; + + size_t vsmem_per_block = 0; + + targets::for_each( + [&](auto candidate) { - constexpr target_arch Arch = decltype(arch_tag)::value; - if(Arch != arch || vsmem_per_block) - return; - using ArchConfig = typename Config::template architecture_config; - using merge_kernel_impl_t = merge_kernel_impl_; - using merge_vsmem_helper_t = detail::vsmem_helper_impl; - - vsmem_per_block = merge_vsmem_helper_t::vsmem_per_block; + if(target{candidate} == most_common_config(t)) + { + using ArchConfig = target_config2; + using merge_kernel_impl_t = merge_kernel_impl_; + using merge_vsmem_helper_t = detail::vsmem_helper_impl; + + vsmem_per_block = merge_vsmem_helper_t::vsmem_per_block; + } }); - if(!vsmem_per_block) - { - using ArchConfig = typename Config::template architecture_config; - using merge_kernel_impl_t = merge_kernel_impl_; - using merge_vsmem_helper_t = detail::vsmem_helper_impl; - - vsmem_per_block = merge_vsmem_helper_t::vsmem_per_block; - } - return vsmem_per_block.value(); + return vsmem_per_block; } template::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_merge_config; + using selector = merge_config_selector; detail::target_arch target_arch; - hipError_t result = detail::host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_config_params params = detail::dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int half_block = block_size / 2; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -115,7 +110,7 @@ inline hipError_t merge_impl(void* temporary_storage, = ((input1_size + input2_size) + items_per_block - 1) / items_per_block; size_t virtual_shared_memory_size - = get_merge_vsmem_size_per_block(target_arch) + = get_merge_vsmem_size_per_block(current_target) * number_of_blocks; unsigned int* index = nullptr; @@ -172,12 +167,12 @@ inline hipError_t merge_impl(void* temporary_storage, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - partition_kernel, - partition_blocks, - half_block, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + partition_kernel, + partition_blocks, + half_block, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("partition_kernel", input1_size, start); @@ -212,12 +207,12 @@ inline hipError_t merge_impl(void* temporary_storage, storage); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - merge_kernel, - dim3(number_of_blocks), - dim3(block_size), - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + merge_kernel, + dim3(number_of_blocks), + dim3(block_size), + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("merge_kernel", input1_size, start); diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge_config.hpp index 5bf72abe710..a585d1ae747 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge_config.hpp @@ -39,40 +39,18 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_merge_config +template +struct merge_config_selector { - template - struct architecture_config - { - static constexpr merge_config_params params = Config(); - }; -}; + using targets = merge_targets; + using param_type = merge_config_params; -// specialized for rocprim::default_config, which instantiates the default_ALGO_config -template -struct wrapped_merge_config -{ - template - struct architecture_config - { - static constexpr merge_config_params params - = default_merge_config(Arch), KeyType, ValueType>{}; - }; -}; + param_type params; -#ifndef DOXYGEN_DOCUMENTATION_BUILD -template -template -constexpr merge_config_params - wrapped_merge_config::architecture_config::params; - -template -template -constexpr merge_config_params - wrapped_merge_config::architecture_config::params; -#endif // DOXYGEN_DOCUMENTATION_BUILD + template + constexpr merge_config_selector(Target) : params(merge_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp index 00b29896068..9014faa23e5 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp @@ -72,17 +72,16 @@ inline hipError_t merge_sort_block_merge_impl( using value_type = typename std::iterator_traits::value_type; constexpr bool with_values = !std::is_same::value; - using config = wrapped_merge_sort_block_merge_config; + using selector = merge_sort_block_merge_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_sort_block_merge_config_params params - = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); const unsigned int merge_oddeven_block_size = params.merge_oddeven_config.block_size; const unsigned int merge_oddeven_items_per_thread = params.merge_oddeven_config.items_per_thread; @@ -211,10 +210,10 @@ inline hipError_t merge_sort_block_merge_impl( // Note: shared memory is not used in this kernel so there is no need to pass vsmem ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + current_target, device_block_merge_mergepath_partition_kernel, dim3(merge_partition_number_of_blocks), dim3(merge_partition_block_size), @@ -260,10 +259,8 @@ inline hipError_t merge_sort_block_merge_impl( storage); }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + current_target, device_block_merge_mergepath_kernel, calculate_grid_dim(merge_mergepath_number_of_blocks, merge_mergepath_block_size), @@ -299,10 +296,8 @@ inline hipError_t merge_sort_block_merge_impl( compare_function); }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan( - target_arch, + execute_launch_plan( + current_target, device_block_merge_oddeven_kernel, dim3(merge_oddeven_number_of_blocks), dim3(merge_oddeven_block_size), @@ -383,17 +378,16 @@ inline hipError_t merge_sort_block_merge( using value_type = typename std::iterator_traits::value_type; constexpr bool with_values = !std::is_same::value; - using config = wrapped_merge_sort_block_merge_config; + using selector = merge_sort_block_merge_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_sort_block_merge_config_params params - = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); const unsigned int merge_mergepath_block_size = params.merge_mergepath_config.block_size; const unsigned int merge_mergepath_items_per_thread = params.merge_mergepath_config.items_per_thread; @@ -476,17 +470,16 @@ inline hipError_t merge_sort_block_sort(KeysInputIterator keys_input, using key_type = typename std::iterator_traits::value_type; using value_type = typename std::iterator_traits::value_type; - using config = wrapped_merge_sort_block_sort_config; + using selector = merge_sort_block_sort_config_selector; detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_sort_block_sort_config_params params - = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); sort_items_per_block = params.kernel_config.block_size * params.kernel_config.items_per_thread; const unsigned int sort_number_of_blocks = ceiling_div(size, sort_items_per_block); @@ -546,8 +539,8 @@ inline hipError_t merge_sort_block_sort(KeysInputIterator keys_input, storage); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( - target_arch, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan( + current_target, block_sort_kernel, calculate_grid_dim(sort_number_of_blocks, params.kernel_config.block_size), params.kernel_config.block_size, @@ -567,12 +560,17 @@ void TAssertEqualGreater() static_assert(A >= B, "A not greater or equal to B"); }; -template +template ROCPRIM_KERNEL void device_merge_sort_compile_time_verifier_arch() { - using BSArchConfig = typename BlockSortConfig::template architecture_config; - using BMArchConfig = typename BlockMergeConfig::template architecture_config; + using BSArchConfig = target_config2; + using BMArchConfig = target_config2; static constexpr auto bs_params = BSArchConfig::params; static constexpr auto bm_params = BMArchConfig::params; @@ -602,36 +600,51 @@ void device_merge_sort_compile_time_verifier_arch() "merge_mergepath_items_per_block"); } -template +template inline void device_merge_sort_compile_time_verifier() noexcept { - static const bool once = [] - { - for_each_arch( - [](auto arch_tag) - { - constexpr auto A = decltype(arch_tag)::value; - (void)&device_merge_sort_compile_time_verifier_arch; - }); - (void)&device_merge_sort_compile_time_verifier_arch; - return true; - }(); - (void)once; + // BSTargets and BMTargets can be different so we do not know at compile time + // the combination of configs that will be chosen. + using BSTargets = typename BSSelector::targets; + using BMTargets = typename BMSelector::targets; + + BSTargets::for_each( + [&](auto t) + { + constexpr target ct = most_common_config(target{t}); + (void)device_merge_sort_compile_time_verifier_arch; + }); + + BMTargets::for_each( + [&](auto t) + { + constexpr target ct = most_common_config(target{t}); + (void)device_merge_sort_compile_time_verifier_arch; + }); } -template -inline size_t merge_sort_vsmem_size_for_arch(size_t size) + class BlockSortSelector, + class BlockMergeSelector, + class Key, + class Value> +inline size_t merge_sort_vsmem_size_for_target(size_t size) { - using BSArchConfig = typename BlockSortConfig::template architecture_config; - using BMArchConfig = typename BlockMergeConfig::template architecture_config; + using BSArchConfig = target_config2; + using BMArchConfig = target_config2; static constexpr auto bs_params = BSArchConfig::params; static constexpr auto bm_params = BMArchConfig::params; @@ -679,35 +692,44 @@ inline size_t merge_sort_vsmem_size_for_arch(size_t size) return virtual_shared_memory_size; } -template -inline size_t get_merge_sort_vsmem_size(detail::target_arch arch, size_t size) noexcept +template +inline size_t get_merge_sort_vsmem_size(detail::target t, size_t size) noexcept { - std::optional out; + using BlockSortTarget = typename BlockSortSelector::targets; + using BlockMergeTarget = typename BlockMergeSelector::targets; - for_each_arch( - [&](auto arch_tag) + size_t vsmem_per_block = 0; + + BlockSortTarget::for_each( + [&](auto BScandidate) { - if(out) - return; - constexpr auto Arch = decltype(arch_tag)::value; - if(Arch != arch) - return; - - out = merge_sort_vsmem_size_for_arch(size); + if(target{BScandidate} == most_common_config(t)) + { + BlockMergeTarget::for_each( + [&](auto BMcandidate) + { + if(target{BMcandidate} == most_common_config(t)) + { + vsmem_per_block + = merge_sort_vsmem_size_for_target(size); + } + }); + } }); - if(!out) - { - out = merge_sort_vsmem_size_for_arch(size); - } - return *out; + + return vsmem_per_block; } template::type; using block_merge_config = typename std:: conditional::type; - using wrapped_bs_config - = wrapped_merge_sort_block_sort_config; - using wrapped_bm_config - = wrapped_merge_sort_block_merge_config; + using selector_bm = merge_sort_block_merge_config_selector; + using selector_bs = merge_sort_block_sort_config_selector; // Some helpful checks during compile-time - device_merge_sort_compile_time_verifier(); + device_merge_sort_compile_time_verifier(); unsigned int sort_items_per_block = 1; // We will get this later from the block_sort algorithm detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const merge_sort_block_merge_config_params params - = dispatch_target_arch(target_arch); + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + detail::gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + const auto params = get_config(block_merge_config{}, current_target); const bool use_mergepath = size > params.merge_oddeven_config.size_limit; const unsigned int merge_mergepath_items_per_block = params.merge_mergepath_config.block_size * params.merge_mergepath_config.items_per_thread; @@ -767,10 +790,12 @@ inline hipError_t merge_sort_impl( // Virtual shared memory part void* vsmem = nullptr; - size_t virtual_shared_memory_size - = get_merge_sort_vsmem_size( - target_arch, - size); + size_t virtual_shared_memory_size = get_merge_sort_vsmem_size(current_target, size); // temporary storage needed for both block merge and block sort size_t* d_merge_partitions = nullptr; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp index 4883c4a08a8..639e2d111a3 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp @@ -86,77 +86,55 @@ namespace detail { // Sub algorithm block_merge: - -template -struct wrapped_merge_sort_block_merge_config +template +struct merge_oddeven_config_static_selector { - template - struct architecture_config - { - static constexpr merge_sort_block_merge_config_params params = MergeSortBlockMergeConfig(); - }; + static constexpr auto block_size + = target_config2::params.merge_oddeven_config.block_size; }; -template -struct wrapped_merge_sort_block_merge_config +template +struct merge_mergepath_partition_config_static_selector { - template - struct architecture_config - { - static constexpr merge_sort_block_merge_config_params params - = default_merge_sort_block_merge_config(Arch), Key, Value>(); - }; + static constexpr auto block_size = target_config2::params + .merge_mergepath_partition_config.block_size; }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr merge_sort_block_merge_config_params - wrapped_merge_sort_block_merge_config:: - architecture_config::params; - -template -template -constexpr merge_sort_block_merge_config_params - wrapped_merge_sort_block_merge_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - -// Sub-algorithm block_sort: -template -struct wrapped_merge_sort_block_sort_config +template +struct merge_mergepath_config_static_selector { - template - struct architecture_config - { - static constexpr merge_sort_block_sort_config_params params = MergeSortBlockSortConfig(); - }; + static constexpr auto block_size + = target_config2::params.merge_mergepath_config.block_size; }; -template -struct wrapped_merge_sort_block_sort_config +template +struct merge_sort_block_merge_config_selector { - template - struct architecture_config - { - static constexpr merge_sort_block_sort_config_params params - = default_merge_sort_block_sort_config(Arch), Key, Value>(); - }; + using targets = merge_sort_block_merge_targets; + using param_type = merge_sort_block_merge_config_params; + + param_type params; + + template + constexpr merge_sort_block_merge_config_selector(Target) + : params(merge_sort_block_merge_config_picker()) + {} }; -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr merge_sort_block_sort_config_params - wrapped_merge_sort_block_sort_config::architecture_config< - Arch>::params; - -template -template -constexpr merge_sort_block_sort_config_params - wrapped_merge_sort_block_sort_config::architecture_config< - Arch>::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS +// Sub-algorithm block_sort: +template +struct merge_sort_block_sort_config_selector +{ + using targets = merge_sort_block_sort_targets; + using param_type = merge_sort_block_sort_config_params; + + param_type params; + + template + constexpr merge_sort_block_sort_config_selector(Target) + : params(merge_sort_block_sort_config_picker()) + {} +}; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp index 0b9ad04f7a0..513a29c4564 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp @@ -165,8 +165,8 @@ inline hipError_t partition_impl(void* temporary_storage, detail::gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const auto items_per_block = block_size * items_per_thread; @@ -212,7 +212,7 @@ inline hipError_t partition_impl(void* temporary_storage, value_type, flag_type, scan_state_type, - block_id_type>(get_target); + block_id_type>(current_target); virtual_shared_memory_size *= number_of_blocks; // temporary storage partition @@ -338,7 +338,7 @@ inline hipError_t partition_impl(void* temporary_storage, predicates...); }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(get_target, + execute_launch_plan(current_target, partition_kernel, dim3(current_number_of_blocks), dim3(block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp index 16e1ca69d69..37fee2bf475 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_radix_sort.hpp @@ -107,9 +107,9 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const radix_sort_onesweep_config_params params = get_config(Config{}, get_target); + const radix_sort_onesweep_config_params params = get_config(Config{}, current_target); const unsigned int items_per_block = params.histogram.block_size * params.histogram.items_per_thread; @@ -153,7 +153,7 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, ROCPRIM_RETURN_ON_ERROR( execute_launch_plan( - get_target, + current_target, onesweep_histograms_kernel, dim3(blocks), dim3(params.histogram.block_size), @@ -176,7 +176,7 @@ hipError_t radix_sort_onesweep_global_offsets(KeysInputIterator keys_input, ROCPRIM_RETURN_ON_ERROR( execute_launch_plan( - get_target, + current_target, onesweep_scan_histograms_kernel, dim3(digit_places), // One block for every digit place. dim3(params.histogram.block_size), @@ -227,9 +227,9 @@ hipError_t radix_sort_onesweep_iteration( gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const radix_sort_onesweep_config_params params = get_config(Config{}, get_target); + const radix_sort_onesweep_config_params params = get_config(Config{}, current_target); const unsigned int items_per_block = params.sort.block_size * params.sort.items_per_thread; const unsigned int current_radix_bits @@ -312,7 +312,7 @@ hipError_t radix_sort_onesweep_iteration( return execute_launch_plan( - get_target, + current_target, onesweep_iteration_kernel, dim3(blocks), dim3(params.sort.block_size), @@ -403,10 +403,10 @@ hipError_t radix_sort_onesweep_impl( gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); const radix_sort_onesweep_config_params params - = get_config(Config{}, get_target); + = get_config(Config{}, current_target); const unsigned int sort_items_per_block = params.sort.block_size * params.sort.items_per_thread; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp index 1368cd0216a..c4a3da51d83 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce.hpp @@ -80,7 +80,7 @@ namespace detail result_type>(input, size, output, initial_value, reduce_op); \ }; \ \ - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, \ + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, \ block_reduce_kernel, \ dim3(1), \ dim3(block_size), \ @@ -117,9 +117,9 @@ inline hipError_t reduce_impl(void* temporary_storage, gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -201,7 +201,7 @@ inline hipError_t reduce_impl(void* temporary_storage, initial_value, reduce_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, block_reduce_kernel, dim3(current_blocks), dim3(block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp index bf1d232cdb2..4b09981c36c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_reduce_by_key.hpp @@ -141,9 +141,9 @@ hipError_t reduce_by_key_impl_wrapped_config(void* temporary gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); using scan_state_type = reduce_by_key::lookback_scan_state_t; @@ -275,7 +275,7 @@ hipError_t reduce_by_key_impl_wrapped_config(void* temporary ordered_bid); }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(get_target, + execute_launch_plan(current_target, kernel, dim3(number_of_blocks_launch), dim3(block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp index 949074a9472..7fd67e97220 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_run_length_encode.hpp @@ -139,9 +139,9 @@ hipError_t run_length_encode_non_trivial_runs_impl(void* tempo gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(non_trivial_config{}, get_target); + const auto params = get_config(non_trivial_config{}, current_target); using scan_state_type = ::rocprim::detail::lookback_scan_state; @@ -228,7 +228,7 @@ hipError_t run_length_encode_non_trivial_runs_impl(void* tempo }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(get_target, + execute_launch_plan(current_target, non_trivial_kernel, dim3(grid_size), dim3(block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp index 82e06315632..06c7690427a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan.hpp @@ -94,9 +94,9 @@ inline auto scan_impl(void* temporary_storage, gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -220,7 +220,7 @@ inline auto scan_impl(void* temporary_storage, block_id); }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(get_target, + execute_launch_plan(current_target, lookback_scan_kernel, dim3(grid_size), dim3(block_size), @@ -292,7 +292,7 @@ inline auto scan_impl(void* temporary_storage, // Save values into output array block_store_type().store(output, values, size, storage.store); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, single_scan_kernel, dim3(1), dim3(block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp index 56ef2e9baed..fc755f4ee15 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_scan_by_key.hpp @@ -178,8 +178,8 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, detail::gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const target current_target(target_arch, target_gpu); + const auto params = get_config(Config{}, current_target); using wrapped_type = ::rocprim::tuple; using scan_state_type = detail::lookback_scan_state; @@ -340,7 +340,7 @@ inline hipError_t scan_by_key_impl(void* const temporary_storage, ordered_bid); }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(get_target, + execute_launch_plan(current_target, device_scan_by_key_kernel, dim3(scan_blocks), dim3(block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp index 213170d0723..320ac214bab 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort.hpp @@ -162,9 +162,9 @@ inline hipError_t segmented_radix_sort_impl( gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); static constexpr bool with_values = !std::is_same::value; const bool partitioning_allowed = params.warp_sort_config.partitioning_allowed; @@ -357,7 +357,7 @@ inline hipError_t segmented_radix_sort_impl( }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(get_target, + execute_launch_plan(current_target, segmented_sort_large_kernel, dim3(large_segment_count), dim3(params.kernel_config.block_size), @@ -398,7 +398,7 @@ inline hipError_t segmented_radix_sort_impl( execute_launch_plan( - get_target, + current_target, segmented_sort_medium_kernel, dim3(medium_segment_grid_size), dim3(params.warp_sort_config.block_size_medium), @@ -439,7 +439,7 @@ inline hipError_t segmented_radix_sort_impl( execute_launch_plan( - get_target, + current_target, segmented_sort_small_kernel, dim3(small_segment_grid_size), dim3(params.warp_sort_config.block_size_small), @@ -474,7 +474,7 @@ inline hipError_t segmented_radix_sort_impl( }; ROCPRIM_RETURN_ON_ERROR( - execute_launch_plan(get_target, + execute_launch_plan(current_target, segmented_sort_kernel, dim3(segments), dim3(params.kernel_config.block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp index 8151620a1df..c32b6b5186b 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_reduce.hpp @@ -71,9 +71,9 @@ inline hipError_t segmented_reduce_impl(void* temporary_storage, gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; @@ -104,7 +104,7 @@ inline hipError_t segmented_reduce_impl(void* temporary_storage, static_cast(initial_value)); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, segmented_reduce_kernel, dim3(segments), dim3(block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp index df3af814961..d40ae1c3837 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_scan.hpp @@ -102,9 +102,9 @@ inline hipError_t segmented_scan_impl(void* temporary_storage, detail::gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; @@ -135,7 +135,7 @@ inline hipError_t segmented_scan_impl(void* temporary_storage, scan_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, segmented_scan_kernel, dim3(segments), dim3(block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp index b1c1679e9e3..8880dcf242f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_transform.hpp @@ -73,9 +73,9 @@ inline hipError_t transform_impl(InputIterator input, detail::gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -123,7 +123,7 @@ inline hipError_t transform_impl(InputIterator input, transform_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, transform_kernel, current_blocks, block_size, diff --git a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp index f01012020be..5fa445e8b49 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_block_sort.hpp @@ -60,9 +60,9 @@ inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); sort_items_per_block = params.block_size * params.items_per_thread; const unsigned int sort_number_of_blocks = ceiling_div(size, sort_items_per_block); @@ -102,7 +102,7 @@ inline hipError_t radix_sort_block_sort(KeysInputIterator keys_input, ROCPRIM_RETURN_ON_ERROR( execute_launch_plan( - get_target, + current_target, radix_sort_block_sort_kernel, dim3(sort_number_of_blocks), dim3(params.block_size), diff --git a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp index 21da5f37cce..681b5bd3746 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/specialization/device_radix_merge_sort.hpp @@ -220,20 +220,21 @@ hipError_t radix_sort_merge_impl( using merge_sort_block_merge_config = typename std:: conditional::type; + using selector_bm = merge_sort_block_merge_config_selector; + using selector_bs = merge_sort_block_sort_config_selector; + // Wrap our radix_sort_block_sort kernel config in a merge_sort_block_sort_config // just so device_merge_sort_compile_time_verifier can check. using merge_sort_block_sort_config = merge_sort_block_sort_config; - using wrapped_bs_config - = wrapped_merge_sort_block_sort_config; - using wrapped_bm_config = wrapped_merge_sort_block_merge_config; // Some helpful checks during compile-time - device_merge_sort_compile_time_verifier(); + device_merge_sort_compile_time_verifier(); // We will get this later from the block_sort algorithm unsigned int sort_items_per_block diff --git a/projects/rocprim/test/rocprim/test_device_scan.cpp b/projects/rocprim/test/rocprim/test_device_scan.cpp index e8d5badfc2e..aecb900a1c1 100644 --- a/projects/rocprim/test/rocprim/test_device_scan.cpp +++ b/projects/rocprim/test/rocprim/test_device_scan.cpp @@ -316,9 +316,9 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScan) rocprim::detail::gpu target_gpu; HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); - const rocprim::detail::target get_target(target_arch, target_gpu); + const rocprim::detail::target current_target(target_arch, target_gpu); - const auto params = rocprim::detail::get_config(Config{}, get_target); + const auto params = rocprim::detail::get_config(Config{}, current_target); // For non-associative operations in inclusive scan // intermediate results use the type of input iterator, then @@ -503,7 +503,7 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScan) ordered_bid); }; return rocprim::detail::execute_launch_plan( - get_target, + current_target, lookback_scan_kernel, dim3(grid_size), dim3(block_size), @@ -570,9 +570,9 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScanGetCompleteValue) rocprim::detail::gpu target_gpu; HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); - const rocprim::detail::target get_target(target_arch, target_gpu); + const rocprim::detail::target current_target(target_arch, target_gpu); - const auto params = rocprim::detail::get_config(Config{}, get_target); + const auto params = rocprim::detail::get_config(Config{}, current_target); // For non-associative operations in inclusive scan // intermediate results use the type of input iterator, then @@ -746,7 +746,7 @@ TYPED_TEST(RocprimDeviceScanTests, LookBackScanGetCompleteValue) false, ordered_bid); }; - return rocprim::detail::execute_launch_plan(get_target, + return rocprim::detail::execute_launch_plan(current_target, lookback_scan_kernel, dim3(grid_size), dim3(block_size), diff --git a/projects/rocprim/test/rocprim/test_device_search_n.cpp b/projects/rocprim/test/rocprim/test_device_search_n.cpp index 07387e453e8..1b8eee34429 100644 --- a/projects/rocprim/test/rocprim/test_device_search_n.cpp +++ b/projects/rocprim/test/rocprim/test_device_search_n.cpp @@ -913,9 +913,9 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_1block) rocprim::detail::gpu target_gpu; HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); - const rocprim::detail::target get_target(target_arch, target_gpu); + const rocprim::detail::target current_target(target_arch, target_gpu); - const auto params = rocprim::detail::get_config(config{}, get_target); + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1039,9 +1039,9 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_2block) rocprim::detail::gpu target_gpu; HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); - const rocprim::detail::target get_target(target_arch, target_gpu); + const rocprim::detail::target current_target(target_arch, target_gpu); - const auto params = rocprim::detail::get_config(config{}, get_target); + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1165,9 +1165,9 @@ TYPED_TEST(RocprimDeviceSearchNTests, NoiseTest_3block) rocprim::detail::gpu target_gpu; HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); - const rocprim::detail::target get_target(target_arch, target_gpu); + const rocprim::detail::target current_target(target_arch, target_gpu); - const auto params = rocprim::detail::get_config(config{}, get_target); + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1291,9 +1291,9 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult1) rocprim::detail::gpu target_gpu; HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); - const rocprim::detail::target get_target(target_arch, target_gpu); + const rocprim::detail::target current_target(target_arch, target_gpu); - const auto params = rocprim::detail::get_config(config{}, get_target); + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -1418,9 +1418,9 @@ TYPED_TEST(RocprimDeviceSearchNTests, MultiResult2) rocprim::detail::gpu target_gpu; HIP_CHECK(rocprim::detail::host_target_gpu(stream, target_gpu)); - const rocprim::detail::target get_target(target_arch, target_gpu); + const rocprim::detail::target current_target(target_arch, target_gpu); - const auto params = rocprim::detail::get_config(config{}, get_target); + const auto params = rocprim::detail::get_config(config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; diff --git a/projects/rocprim/test/rocprim/test_linking_new_scan.hpp b/projects/rocprim/test/rocprim/test_linking_new_scan.hpp index ba048b02204..a38953bc99e 100644 --- a/projects/rocprim/test/rocprim/test_linking_new_scan.hpp +++ b/projects/rocprim/test/rocprim/test_linking_new_scan.hpp @@ -106,9 +106,9 @@ inline auto scan_impl(void* temporary_storage, rocprim::detail::gpu target_gpu; ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); - const target get_target(target_arch, target_gpu); + const target current_target(target_arch, target_gpu); - const auto params = get_config(Config{}, get_target); + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; @@ -135,7 +135,7 @@ inline auto scan_impl(void* temporary_storage, output, scan_op); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(get_target, + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, single_scan_kernel, dim3(1), dim3(block_size), From 9cf6496bf5db4dd54655b4ae4bdda8ecc0dd77a4 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 13 Nov 2025 11:11:15 +0000 Subject: [PATCH 08/26] Update the config for radix_onesweep based on upstream changes --- .../config/device_radix_sort_onesweep.hpp | 1306 ++++++++--------- 1 file changed, 653 insertions(+), 653 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp index c761e93bc8d..9c942f95e3f 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_radix_sort_onesweep.hpp @@ -52,8 +52,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 1}, - kernel_config_params{1024, 1}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -64,8 +64,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -76,8 +76,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -88,8 +88,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -100,8 +100,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -124,8 +124,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 1}, - kernel_config_params{1024, 1}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -148,8 +148,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 8}, - kernel_config_params{1024, 8}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -172,8 +172,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -195,8 +195,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -217,8 +217,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -240,8 +240,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 22}, - kernel_config_params{512, 22}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -251,8 +251,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -311,8 +311,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 1}, - kernel_config_params{1024, 1}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -335,8 +335,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 1}, - kernel_config_params{1024, 1}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -347,8 +347,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 1}, - kernel_config_params{1024, 1}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -359,8 +359,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -371,8 +371,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -383,8 +383,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -419,8 +419,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -431,8 +431,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -455,8 +455,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -479,8 +479,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, 8, block_radix_rank_algorithm::match }; @@ -503,8 +503,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 8}, - kernel_config_params{1024, 8}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -515,8 +515,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -527,8 +527,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 16}, - kernel_config_params{512, 16}, + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, 8, block_radix_rank_algorithm::match }; @@ -550,8 +550,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, 8, block_radix_rank_algorithm::match }; @@ -572,8 +572,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -583,8 +583,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 8}, - kernel_config_params{1024, 8}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -595,8 +595,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -628,8 +628,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 16}, - kernel_config_params{128, 16}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -640,8 +640,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -652,8 +652,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, 8, block_radix_rank_algorithm::match }; @@ -664,8 +664,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -676,8 +676,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -688,8 +688,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -700,8 +700,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -724,8 +724,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -736,8 +736,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -748,8 +748,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 8}, - kernel_config_params{1024, 8}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -760,8 +760,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 22}, - kernel_config_params{512, 22}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -771,8 +771,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -793,8 +793,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -804,8 +804,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -827,8 +827,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -839,8 +839,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 1}, - kernel_config_params{1024, 1}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -851,8 +851,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 16}, - kernel_config_params{128, 16}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -863,8 +863,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 16}, - kernel_config_params{128, 16}, + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, 8, block_radix_rank_algorithm::match }; @@ -875,8 +875,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 16}, - kernel_config_params{128, 16}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -887,8 +887,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 16}, - kernel_config_params{128, 16}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -899,8 +899,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 1}, + kernel_config_params{1024, 1}, 8, block_radix_rank_algorithm::match }; @@ -911,8 +911,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -923,8 +923,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -935,8 +935,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 18}, + kernel_config_params{1024, 18}, 8, block_radix_rank_algorithm::match }; @@ -947,8 +947,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -959,8 +959,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 16}, - kernel_config_params{256, 16}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -971,8 +971,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -983,8 +983,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1019,8 +1019,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 4}, - kernel_config_params{1024, 4}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1031,8 +1031,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 8}, - kernel_config_params{1024, 8}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1043,8 +1043,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 22}, - kernel_config_params{512, 22}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -1055,8 +1055,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -1079,8 +1079,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -1115,8 +1115,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -1126,8 +1126,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{128, 22}, - kernel_config_params{128, 22}, + kernel_config_params{512, 22}, + kernel_config_params{512, 22}, 8, block_radix_rank_algorithm::match }; @@ -1171,8 +1171,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1182,8 +1182,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -1195,7 +1195,7 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< template constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< std::is_same>::value, + comp_target>::value, radix_sort_onesweep_config_params> { // Based on key_type = double, value_type = rocprim::int128_t @@ -1204,8 +1204,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 6}, - kernel_config_params{256, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1216,8 +1216,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1228,8 +1228,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1240,8 +1240,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1252,8 +1252,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1264,8 +1264,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1276,8 +1276,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1288,8 +1288,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1300,8 +1300,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1312,8 +1312,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 16}, - kernel_config_params{256, 16}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1324,8 +1324,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1336,8 +1336,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1347,8 +1347,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1358,8 +1358,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1369,8 +1369,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1380,8 +1380,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1392,8 +1392,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, 8, block_radix_rank_algorithm::match }; @@ -1403,8 +1403,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -1415,8 +1415,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -1427,8 +1427,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1439,8 +1439,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1451,8 +1451,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1463,8 +1463,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1475,8 +1475,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 6}, - kernel_config_params{256, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -1487,8 +1487,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 6}, - kernel_config_params{256, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1499,8 +1499,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1511,8 +1511,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1523,8 +1523,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1535,8 +1535,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -1547,8 +1547,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1559,8 +1559,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -1571,8 +1571,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1583,8 +1583,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1595,8 +1595,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1607,8 +1607,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 18}, - kernel_config_params{256, 18}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1619,8 +1619,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1631,8 +1631,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1643,8 +1643,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1655,8 +1655,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 22}, - kernel_config_params{512, 22}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1667,9 +1667,9 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, - 4, + kernel_config_params{512, 18}, + kernel_config_params{512, 18}, + 8, block_radix_rank_algorithm::match }; } @@ -1679,8 +1679,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, 8, block_radix_rank_algorithm::match }; @@ -1691,8 +1691,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -1702,8 +1702,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1713,8 +1713,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1724,8 +1724,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 16}, - kernel_config_params{256, 16}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1735,8 +1735,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 8}, - kernel_config_params{512, 8}, + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, 8, block_radix_rank_algorithm::match }; @@ -1747,10 +1747,10 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 16}, - kernel_config_params{512, 16}, - 4, - block_radix_rank_algorithm::basic + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, + 8, + block_radix_rank_algorithm::match }; } // Based on key_type = int8_t, value_type = empty_type @@ -1758,8 +1758,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{256, 18}, + kernel_config_params{256, 18}, 8, block_radix_rank_algorithm::match }; @@ -1780,8 +1780,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -1792,8 +1792,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -1804,8 +1804,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -1816,8 +1816,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -1840,8 +1840,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -1864,8 +1864,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -1876,8 +1876,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, 8, block_radix_rank_algorithm::match }; @@ -1888,8 +1888,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -1900,8 +1900,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1912,8 +1912,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -1923,8 +1923,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -1934,8 +1934,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -1945,8 +1945,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -1968,8 +1968,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 16}, - kernel_config_params{1024, 16}, + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, 8, block_radix_rank_algorithm::match }; @@ -1991,8 +1991,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -2003,8 +2003,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -2015,8 +2015,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 6}, - kernel_config_params{256, 6}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -2027,8 +2027,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -2039,8 +2039,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -2063,8 +2063,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -2075,8 +2075,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -2087,8 +2087,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -2099,8 +2099,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -2123,8 +2123,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -2147,8 +2147,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -2159,8 +2159,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, 8, block_radix_rank_algorithm::match }; @@ -2171,8 +2171,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 16}, - kernel_config_params{256, 16}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -2183,8 +2183,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -2195,9 +2195,9 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, - 7, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, + 8, block_radix_rank_algorithm::match }; } @@ -2219,8 +2219,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 18}, - kernel_config_params{256, 18}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -2231,8 +2231,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 18}, - kernel_config_params{256, 18}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -2243,8 +2243,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 18}, - kernel_config_params{1024, 18}, + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, 8, block_radix_rank_algorithm::match }; @@ -2267,8 +2267,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -2278,9 +2278,9 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, - 8, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + 8, block_radix_rank_algorithm::match }; } @@ -2289,8 +2289,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -2300,8 +2300,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 22}, - kernel_config_params{256, 22}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -2311,10 +2311,10 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 18}, - kernel_config_params{1024, 18}, - 8, - block_radix_rank_algorithm::match + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, + 4, + block_radix_rank_algorithm::basic }; } // Based on key_type = int8_t, value_type = int8_t @@ -2323,8 +2323,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 22}, - kernel_config_params{1024, 22}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 4, block_radix_rank_algorithm::basic }; @@ -2334,10 +2334,10 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 16}, - kernel_config_params{256, 16}, - 8, - block_radix_rank_algorithm::match + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 4, + block_radix_rank_algorithm::basic }; } // Default case if none of the conditions match @@ -2356,8 +2356,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -2368,8 +2368,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -2380,8 +2380,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -2392,8 +2392,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -2404,8 +2404,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -2416,8 +2416,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -2428,8 +2428,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -2440,8 +2440,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -2452,8 +2452,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2464,8 +2464,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2476,8 +2476,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2488,8 +2488,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -2499,8 +2499,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -2510,8 +2510,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -2521,8 +2521,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2555,8 +2555,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 16}, - kernel_config_params{512, 16}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -2567,8 +2567,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -2579,8 +2579,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -2591,8 +2591,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -2603,8 +2603,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -2615,8 +2615,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -2627,8 +2627,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -2639,8 +2639,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -2663,8 +2663,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -2675,8 +2675,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -2687,8 +2687,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -2699,8 +2699,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -2711,8 +2711,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -2723,8 +2723,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2735,8 +2735,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2747,8 +2747,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2759,8 +2759,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2771,8 +2771,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -2783,8 +2783,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -2795,8 +2795,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -2807,9 +2807,9 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 6}, - kernel_config_params{1024, 6}, - 6, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, + 8, block_radix_rank_algorithm::match }; } @@ -2831,8 +2831,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 22}, - kernel_config_params{512, 22}, + kernel_config_params{1024, 12}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -2854,8 +2854,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -2865,8 +2865,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -2876,8 +2876,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -2899,8 +2899,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 4, block_radix_rank_algorithm::basic }; @@ -2923,7 +2923,7 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< template constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< std::is_same>::value, + comp_target>::value, radix_sort_onesweep_config_params> { // Based on key_type = double, value_type = rocprim::int128_t @@ -2932,8 +2932,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -2944,8 +2944,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -2956,8 +2956,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 12}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -2968,8 +2968,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 12}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -2980,8 +2980,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -2992,8 +2992,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3004,8 +3004,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3016,8 +3016,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3028,8 +3028,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3040,8 +3040,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3052,8 +3052,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3064,8 +3064,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3075,8 +3075,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3086,8 +3086,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3097,8 +3097,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3108,8 +3108,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 22}, + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, 8, block_radix_rank_algorithm::match }; @@ -3120,8 +3120,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 22}, + kernel_config_params{256, 16}, + kernel_config_params{256, 16}, 8, block_radix_rank_algorithm::match }; @@ -3131,8 +3131,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -3143,8 +3143,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -3155,8 +3155,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -3167,8 +3167,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 32}, - kernel_config_params{256, 12}, + kernel_config_params{512, 4}, + kernel_config_params{512, 4}, 8, block_radix_rank_algorithm::match }; @@ -3179,8 +3179,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{256, 8}, + kernel_config_params{256, 8}, 8, block_radix_rank_algorithm::match }; @@ -3191,8 +3191,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{256, 8}, + kernel_config_params{256, 8}, 8, block_radix_rank_algorithm::match }; @@ -3203,8 +3203,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{256, 6}, + kernel_config_params{256, 6}, 8, block_radix_rank_algorithm::match }; @@ -3215,8 +3215,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3227,8 +3227,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3239,8 +3239,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3251,8 +3251,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3263,8 +3263,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3275,8 +3275,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3287,8 +3287,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3299,8 +3299,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3311,8 +3311,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3323,8 +3323,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3335,8 +3335,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3347,8 +3347,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -3359,8 +3359,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3371,8 +3371,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3383,8 +3383,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3395,9 +3395,9 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 22}, - 6, + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, + 8, block_radix_rank_algorithm::match }; } @@ -3407,8 +3407,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 22}, - kernel_config_params{1024, 22}, + kernel_config_params{256, 22}, + kernel_config_params{256, 22}, 8, block_radix_rank_algorithm::match }; @@ -3419,8 +3419,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -3430,8 +3430,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3441,8 +3441,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 6}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3452,8 +3452,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3463,8 +3463,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 22}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3475,8 +3475,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 22}, - kernel_config_params{1024, 22}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 4, block_radix_rank_algorithm::basic }; @@ -3486,8 +3486,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -3499,7 +3499,7 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< template constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< std::is_same>::value, + comp_target>::value, radix_sort_onesweep_config_params> { // Based on key_type = double, value_type = rocprim::int128_t @@ -3508,8 +3508,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3520,8 +3520,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -3532,7 +3532,7 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, + kernel_config_params{512, 32}, kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match @@ -3544,7 +3544,7 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, + kernel_config_params{512, 32}, kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match @@ -3556,8 +3556,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3580,8 +3580,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3592,8 +3592,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -3604,8 +3604,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -3616,8 +3616,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -3628,8 +3628,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -3640,8 +3640,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -3651,8 +3651,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3662,8 +3662,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3673,8 +3673,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -3684,8 +3684,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 18}, - kernel_config_params{1024, 18}, + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, 8, block_radix_rank_algorithm::match }; @@ -3696,8 +3696,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, 8, block_radix_rank_algorithm::match }; @@ -3707,8 +3707,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 18}, - kernel_config_params{256, 18}, + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -3719,8 +3719,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3731,8 +3731,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -3743,8 +3743,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -3755,8 +3755,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -3767,8 +3767,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 32}, + kernel_config_params{512, 6}, 8, block_radix_rank_algorithm::match }; @@ -3779,8 +3779,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 6}, - kernel_config_params{512, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -3791,8 +3791,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3803,8 +3803,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -3815,8 +3815,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -3827,8 +3827,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -3839,8 +3839,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -3863,8 +3863,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3875,8 +3875,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -3887,8 +3887,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -3899,8 +3899,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -3923,8 +3923,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 16}, - kernel_config_params{512, 16}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -3935,8 +3935,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3947,8 +3947,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -3959,8 +3959,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -3971,9 +3971,9 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 18}, - kernel_config_params{1024, 18}, - 8, + kernel_config_params{512, 32}, + kernel_config_params{512, 22}, + 6, block_radix_rank_algorithm::match }; } @@ -3983,8 +3983,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, 8, block_radix_rank_algorithm::match }; @@ -3995,8 +3995,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 18}, - kernel_config_params{256, 18}, + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -4006,8 +4006,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 12}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4017,8 +4017,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 8) && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 12}, - kernel_config_params{512, 12}, + kernel_config_params{1024, 32}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4028,8 +4028,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -4039,8 +4039,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 2) && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 18}, - kernel_config_params{512, 18}, + kernel_config_params{1024, 32}, + kernel_config_params{1024, 22}, 8, block_radix_rank_algorithm::match }; @@ -4051,10 +4051,10 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 16}, - kernel_config_params{256, 16}, - 8, - block_radix_rank_algorithm::match + kernel_config_params{1024, 22}, + kernel_config_params{1024, 22}, + 4, + block_radix_rank_algorithm::basic }; } // Based on key_type = int8_t, value_type = empty_type @@ -4062,8 +4062,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 18}, - kernel_config_params{256, 18}, + kernel_config_params{1024, 32}, + kernel_config_params{1024, 12}, 8, block_radix_rank_algorithm::match }; @@ -4084,8 +4084,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4096,8 +4096,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -4144,8 +4144,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -4156,8 +4156,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4168,8 +4168,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -4180,8 +4180,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -4192,8 +4192,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -4204,8 +4204,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -4216,8 +4216,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -4227,8 +4227,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4249,8 +4249,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -4307,8 +4307,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -4319,8 +4319,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{256, 32}, - kernel_config_params{256, 12}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -4331,8 +4331,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 4}, + kernel_config_params{1024, 4}, 8, block_radix_rank_algorithm::match }; @@ -4355,8 +4355,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{512, 8}, + kernel_config_params{512, 8}, 8, block_radix_rank_algorithm::match }; @@ -4367,8 +4367,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4379,8 +4379,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -4391,8 +4391,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -4403,8 +4403,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -4415,8 +4415,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -4427,8 +4427,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{512, 12}, + kernel_config_params{512, 12}, 8, block_radix_rank_algorithm::match }; @@ -4439,8 +4439,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4451,8 +4451,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 4))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 6}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -4463,8 +4463,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -4475,8 +4475,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 1))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -4487,8 +4487,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (!std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{512, 16}, + kernel_config_params{512, 16}, 8, block_radix_rank_algorithm::match }; @@ -4499,8 +4499,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (std::is_same::value))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 8}, + kernel_config_params{1024, 8}, 8, block_radix_rank_algorithm::match }; @@ -4511,8 +4511,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4535,8 +4535,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 12}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -4582,8 +4582,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 16) && (sizeof(value_type) > 8))) { return radix_sort_onesweep_config_params{ - kernel_config_params{512, 32}, - kernel_config_params{512, 6}, + kernel_config_params{1024, 6}, + kernel_config_params{1024, 6}, 8, block_radix_rank_algorithm::match }; @@ -4604,8 +4604,8 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< && (sizeof(value_type) <= 4) && (sizeof(value_type) > 2))) { return radix_sort_onesweep_config_params{ - kernel_config_params{1024, 32}, - kernel_config_params{1024, 12}, + kernel_config_params{1024, 16}, + kernel_config_params{1024, 16}, 8, block_radix_rank_algorithm::match }; @@ -4664,11 +4664,11 @@ constexpr auto radix_sort_onesweep_config_picker() -> std::enable_if_t< using radix_sort_onesweep_targets = comp_targets, comp_target, - comp_target, + comp_target, comp_target, comp_target, + comp_target, comp_target, - comp_target, comp_target, comp_target>; From 43a9915ab77407f5c9eae0d96d14d3ba663cb9e7 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Mon, 17 Nov 2025 06:39:49 +0000 Subject: [PATCH 09/26] Resolve "New Config system tests" --- .../include/rocprim/device/config_types.hpp | 20 +- .../test/rocprim/test_config_dispatch.cpp | 284 ++++++++++++++++++ 2 files changed, 300 insertions(+), 4 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index e812fdae740..5c411040575 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -596,7 +596,8 @@ template - class LaunchSelector> + class LaunchSelector, + bool PassTarget = false> ROCPRIM_KERNEL __launch_bounds__((LaunchSelector::block_size)) void trampoline_kernel(Kernel kernel) { @@ -609,7 +610,14 @@ void trampoline_kernel(Kernel kernel) if constexpr(Target::i == device_arch_target.i) #endif { - kernel(ArchConfig{}); + if constexpr(PassTarget) + { + kernel(ArchConfig{}, Target{}); + } + else + { + kernel(ArchConfig{}); + } } #if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 else @@ -622,6 +630,7 @@ void trampoline_kernel(Kernel kernel) template class LaunchSelector = default_config_static_selector, + bool PassTarget = false, class Kernel> auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan { @@ -641,7 +650,8 @@ auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan; + LaunchSelector, + PassTarget>; } }); @@ -668,11 +678,13 @@ hipError_t execute_launch_plan(target_arch arch, template class LaunchSelector = default_config_static_selector, + bool PassTarget = false, class Kernel> hipError_t execute_launch_plan( target t, Kernel kernel, dim3 grid_size, dim3 block_size, size_t shmem, hipStream_t stream) { - const auto launch_plan = make_launch_plan(t, kernel); + const auto launch_plan + = make_launch_plan(t, kernel); launch_plan.launch(grid_size, block_size, shmem, stream); return hipGetLastError(); } diff --git a/projects/rocprim/test/rocprim/test_config_dispatch.cpp b/projects/rocprim/test/rocprim/test_config_dispatch.cpp index 29de22eb0e6..547c85e993a 100644 --- a/projects/rocprim/test/rocprim/test_config_dispatch.cpp +++ b/projects/rocprim/test/rocprim/test_config_dispatch.cpp @@ -160,3 +160,287 @@ TEST(RocprimConfigDispatchTests, DeviceIdFromStream) ASSERT_EQ(result, device_id); } #endif + +namespace rocprim +{ +namespace detail +{ + +using Targets + = comp_targets, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target, + comp_target>; + +struct test_config_params +{ + kernel_config_params kernel_config{}; + + __host__ __device__ + inline bool + operator==(test_config_params other) const + { + return (kernel_config.block_size == other.kernel_config.block_size) + && (kernel_config.items_per_thread == other.kernel_config.items_per_thread); + } +}; + +template +struct TestSelector +{ + using targets = TargetsInput; + using param_type = test_config_params; + + param_type params; + + template + constexpr TestSelector(Target) + : params(test_config_params{ + {BlockSize, ItemsPerThread} + }) + {} +}; + +template +struct TestSelector2 +{ + using targets = TargetsInput; + using param_type = test_config_params; + + param_type params; + + template + constexpr param_type picker_helper() + { + constexpr auto arch = Target::i; + constexpr auto gen = Target::g; + constexpr auto gpu = Target::s; + + // Assign unique configs per architecture/gen/GPU + if constexpr(arch == target_arch::gfx1030 && gen == gen::rdna2 && gpu == gpu::rx6900) + { + return param_type{ + {64, 1} + }; + } + else if constexpr(arch == target_arch::gfx1100 && gen == gen::rdna3 && gpu == gpu::rx7900) + { + return param_type{ + {128, 2} + }; + } + else if constexpr(arch == target_arch::gfx908 && gen == gen::cdna1 && gpu == gpu::mi100) + { + return param_type{ + {192, 3} + }; + } + else if constexpr(arch == target_arch::gfx90a && gen == gen::cdna2 && gpu == gpu::mi210) + { + return param_type{ + {256, 4} + }; + } + else if constexpr(arch == target_arch::gfx942 && gen == gen::cdna3 && gpu == gpu::mi300x) + { + return param_type{ + {320, 5} + }; + } + else if constexpr(arch == target_arch::gfx942 && gen == gen::cdna3 && gpu == gpu::mi300a) + { + return param_type{ + {384, 6} + }; + } + else if constexpr(arch == target_arch::gfx1200 && gen == gen::rdna4 && gpu == gpu::rx9060) + { + return param_type{ + {448, 7} + }; + } + else if constexpr(arch == target_arch::gfx1201 && gen == gen::rdna4 && gpu == gpu::rx9070) + { + return param_type{ + {512, 8} + }; + } + else if constexpr(arch == target_arch::gfx950 && gen == gen::cdna4 && gpu == gpu::mi350x) + { + return param_type{ + {576, 9} + }; + } + else + { + // Default fallback for unknown targets + return param_type{ + {1, 1} + }; + } + } + + template + constexpr TestSelector2(Target) : params(picker_helper()) + {} +}; + +} // namespace detail +} // namespace rocprim + +// This test needs to be changed once the selection logic changes for most_common_config. +TEST(RocprimConfigDispatchTests, MostCommonConfig) +{ + using namespace rocprim::detail; + + // 1. Default-constructed target (all unknowns) should return itself. + ASSERT_EQ(most_common_config(target()), target()); + // 2. Exact match (same arch & gpu) should return that target. + ASSERT_EQ(most_common_config(target(target_arch::gfx1200, gpu::rx9060)), + target(target_arch::gfx1200, gpu::rx9060)); + // 3. Unknown arch/gpu combination should return default unknown. + ASSERT_EQ(most_common_config(target(target_arch::gfx906, gpu::mi50)), target()); + // 4. Multiple entries with same arch — latest wins (mi300a over mi300x). + ASSERT_EQ(most_common_config(target(target_arch::gfx942, gpu::mi308x)), + target(target_arch::gfx942, gpu::mi300a)); + // 5. Target with same generation but unknown arch — picks matching gen target. + ASSERT_EQ(most_common_config(target(target_arch::gfx1102)), + target(target_arch::gfx1100, gpu::rx7900)); + // 6. Generation-only match (no arch/gpu known). + ASSERT_EQ(most_common_config(target(gen::rdna2)), + target(target_arch::gfx1030, gpu::rx6900)); + // 7. Arch-only match (gpu::generic, known arch). + ASSERT_EQ(most_common_config(target(target_arch::gfx90a)), + target(target_arch::gfx90a, gpu::mi210)); + // 8. If gen has multiple matching targets (cdna3), latest listed is chosen. + ASSERT_EQ(most_common_config(target(gen::cdna3)), + target(target_arch::gfx942, gpu::mi300a)); +} + +// This test needs to be changed once the selection logic changes for most_common_config. +TEST(RocprimConfigDispatchTests, DefaultSelectConfig) +{ + using namespace rocprim::detail; + + using Selector = TestSelector2; + using Params = typename Selector::param_type; + + auto cfg = [](target t) { return default_select_config(t); }; + + // 1. Default target (unknown) + ASSERT_EQ(cfg(target()), + (Params{ + {1, 1} + })); + + // 2. Exact match (gfx1200, rx9060) + ASSERT_EQ(cfg(target(target_arch::gfx1200, gpu::rx9060)), + (Params{ + {448, 7} + })); + + // 3. Unknown arch/gpu + ASSERT_EQ(cfg(target(target_arch::gfx906, gpu::mi50)), + (Params{ + {1, 1} + })); + + // 4. Multiple entries same arch (gfx942) — latest wins (mi300a) + ASSERT_EQ(cfg(target(target_arch::gfx942, gpu::mi308x)), + (Params{ + {384, 6} + })); + + // 5. Same gen but unknown arch (gfx1102) + ASSERT_EQ(cfg(target(target_arch::gfx1102)), + (Params{ + {128, 2} + })); + + // 6. Generation-only match (rdna2) + ASSERT_EQ(cfg(target(gen::rdna2)), + (Params{ + {64, 1} + })); + + // 7. Arch-only match (gfx90a) + ASSERT_EQ(cfg(target(target_arch::gfx90a)), + (Params{ + {256, 4} + })); + + // 8. Generation with multiple matches (cdna3 -> gfx942/mi300a) + ASSERT_EQ(cfg(target(gen::cdna3)), + (Params{ + {384, 6} + })); + + // 9. Other targets in the list + ASSERT_EQ(cfg(target(target_arch::gfx1201, gpu::rx9070)), + (Params{ + {512, 8} + })); + ASSERT_EQ(cfg(target(target_arch::gfx950, gpu::mi350x)), + (Params{ + {576, 9} + })); +} + +TEST(RocprimConfigDispatchTests, ExecuteLaunchPlan) +{ + using namespace rocprim::detail; + using EmptyTargets + = comp_targets>; + using Config = rocprim::default_config; + + constexpr unsigned int block_size = 256; + constexpr unsigned int ipt = 2; + + hipStream_t stream = 0; + + target_arch target_arch; + HIP_CHECK(host_target_arch(stream, target_arch)); + gpu target_gpu; + HIP_CHECK(host_target_gpu(stream, target_gpu)); + + const target current_target(target_arch, target_gpu); + + target* d_output; + HIP_CHECK(hipMalloc(&d_output, sizeof(target))); + + auto kernel = [=](auto arch_config, auto t) + { + (void)arch_config; + *d_output = target{t}; + }; + + HIP_CHECK( + (execute_launch_plan, + default_config_static_selector, + true>(current_target, kernel, dim3(1), dim3(block_size), 0, stream))); + + // Impossible target + target h_output{gen::cdna4, target_arch::gfx942, gpu::rx6900, rep::spirv}; + HIP_CHECK(hipMemcpy(&h_output, d_output, sizeof(target), hipMemcpyDeviceToHost)); + // Compared to targets with only unknown inside. + ASSERT_EQ(target(), h_output); + + HIP_CHECK( + (execute_launch_plan, + default_config_static_selector, + true>(current_target, kernel, dim3(1), dim3(block_size), 0, stream))); + + // Impossible target + h_output = target{gen::cdna4, target_arch::gfx942, gpu::rx6900, rep::spirv}; + HIP_CHECK(hipMemcpy(&h_output, d_output, sizeof(target), hipMemcpyDeviceToHost)); + // Should have the same targets as most_common_config. + ASSERT_EQ(most_common_config(current_target), h_output); +} From 73b13ac3edb59de62c24933f96f4abc4aba28066 Mon Sep 17 00:00:00 2001 From: Saiyang Zhang Date: Mon, 24 Nov 2025 08:35:50 +0000 Subject: [PATCH 10/26] Resolve "Consistency in config tags" --- .../device/detail/device_config_helper.hpp | 72 +------------------ .../rocprim/device/device_binary_search.hpp | 7 -- 2 files changed, 3 insertions(+), 76 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 36c8cd23414..fdfb1403ea7 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -223,9 +223,6 @@ struct radix_sort_onesweep_config : detail::radix_sort_onesweep_config_params namespace detail { -struct reduce_config_tag -{}; - // Calculate kernel configurations, such that it will not exceed shared memory maximum template constexpr radix_sort_onesweep_config_params radix_sort_onesweep_config_params_base() @@ -261,8 +258,6 @@ template struct reduce_config : rocprim::detail::reduce_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::reduce_config_tag; constexpr reduce_config() : rocprim::detail::reduce_config_params{ {BlockSize, ItemsPerThread, SizeLimit}, @@ -286,9 +281,6 @@ constexpr reduce_config_params reduce_config_params_base() }; }; -struct scan_config_tag -{}; - /// \brief Provides the kernel parameters for exclusive_scan and inclusive_scan based /// on autotuned configurations or user-provided configurations. struct scan_config_params @@ -317,8 +309,6 @@ template struct scan_config : ::rocprim::detail::scan_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::scan_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS // Requirement dictated by init_lookback_scan_state_kernel. static_assert(BlockSize <= ROCPRIM_DEFAULT_MAX_BLOCK_SIZE, @@ -350,9 +340,6 @@ struct scan_config : ::rocprim::detail::scan_config_params namespace detail { -struct scan_by_key_config_tag -{}; - template constexpr scan_config_params scan_config_params_base() { @@ -396,8 +383,6 @@ template struct scan_by_key_config : ::rocprim::detail::scan_by_key_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::scan_by_key_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS // Requirement dictated by init_lookback_scan_state_kernel. static_assert(BlockSize <= ROCPRIM_DEFAULT_MAX_BLOCK_SIZE, @@ -445,9 +430,6 @@ constexpr scan_by_key_config_params scan_by_key_config_params_base() }; }; -struct transform_config_tag -{}; - struct transform_config_params { kernel_config_params kernel_config = {0, 0}; @@ -458,8 +440,6 @@ struct transform_config_params namespace detail { -struct segmented_radix_sort_config_tag -{}; struct warp_sort_config_params { @@ -591,8 +571,6 @@ template struct segmented_radix_sort_config : public detail::segmented_radix_sort_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::segmented_radix_sort_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS /// \brief Number of bits in iterations. @@ -659,8 +637,6 @@ template struct transform_config : public detail::transform_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::transform_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS /// \brief Number of threads in a block. @@ -694,8 +670,6 @@ template struct transform_pointer_config : public detail::transform_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::transform_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS /// \brief Number of threads in a block. @@ -744,13 +718,6 @@ constexpr transform_config_params transform_config_params_base() }; } -struct binary_search_config_tag : public transform_config_tag -{}; -struct upper_bound_config_tag : public transform_config_tag -{}; -struct lower_bound_config_tag : public transform_config_tag -{}; - } // namespace detail /// \brief Configuration for the device-level binary search operation. @@ -761,10 +728,7 @@ template struct binary_search_config : transform_config -{ - /// \brief Identifies the algorithm associated to the config. - using tag = detail::binary_search_config_tag; -}; +{}; /// \brief Configuration for the device-level upper bound operation. /// \tparam BlockSize Number of threads in a block. @@ -774,10 +738,7 @@ template struct upper_bound_config : transform_config -{ - /// \brief Identifies the algorithm associated to the config. - using tag = detail::upper_bound_config_tag; -}; +{}; /// \brief Configuration for the device-level lower bound operation. /// \tparam BlockSize Number of threads in a block. @@ -787,17 +748,11 @@ template struct lower_bound_config : transform_config -{ - /// \brief Identifies the algorithm associated to the config. - using tag = detail::lower_bound_config_tag; -}; +{}; namespace detail { -struct histogram_config_tag -{}; - template constexpr transform_config_params binary_search_config_params_base() { @@ -840,8 +795,6 @@ template struct histogram_config : detail::histogram_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::histogram_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS using histogram = HistogramConfig; @@ -872,9 +825,6 @@ constexpr histogram_config_params histogram_config_params_base() return histogram_config_params{kernel_params, 1024, 2048, 3, kernel_params}; }; -struct adjacent_difference_config_tag -{}; - struct adjacent_difference_config_params { kernel_config_params kernel_config{}; @@ -897,8 +847,6 @@ template struct adjacent_difference_config : public detail::adjacent_difference_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::adjacent_difference_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS static constexpr ::rocprim::block_load_method block_load_method = BlockLoadMethod; static constexpr ::rocprim::block_store_method block_store_method = BlockStoreMethod; @@ -936,9 +884,6 @@ constexpr adjacent_difference_config_params adjacent_difference_config_params_ba namespace detail { -struct batch_memcpy_config_tag -{}; - struct batch_memcpy_config_params { /// \brief Kernel config for thread- and warp-level copy @@ -975,8 +920,6 @@ template struct batch_memcpy_config : public detail::batch_memcpy_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::batch_memcpy_config_tag; #ifndef DOXYGEN_SHOULD_SKIP_THIS static constexpr unsigned int non_blev_block_size = NonBlevBlockSize; static constexpr unsigned int non_blev_items_per_thread = NonBlevItemsPerThread; @@ -1276,8 +1219,6 @@ struct nth_element_config : public detail::nth_element_config_params namespace detail { -struct non_trivial_runs_config_tag -{}; struct non_trivial_runs_config_params { @@ -1302,8 +1243,6 @@ template struct non_trivial_runs_config : public detail::non_trivial_runs_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::non_trivial_runs_config_tag; #ifndef DOXYGEN_DOCUMENTATION_BUILD /// \brief Number of threads in a block. static constexpr unsigned int block_size = BlockSize; @@ -1368,9 +1307,6 @@ struct find_first_of_config_params kernel_config_params kernel_config{}; }; -struct adjacent_find_config_tag -{}; - struct adjacent_find_config_params { kernel_config_params kernel_config{}; @@ -1401,8 +1337,6 @@ struct find_first_of_config : public detail::find_first_of_config_params template struct adjacent_find_config : public detail::adjacent_find_config_params { - /// \brief Identifies the algorithm associated to the config. - using tag = detail::adjacent_find_config_tag; #ifndef DOXYGEN_DOCUMENTATION_BUILD constexpr adjacent_find_config() : detail::adjacent_find_config_params{ diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp index effaf4ccf5d..3a00502eef8 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_binary_search.hpp @@ -209,9 +209,6 @@ hipError_t lower_bound(void * temporary_storage, hipStream_t stream = 0, bool debug_synchronous = false) { - static_assert(detail::is_default_or_has_tag::value, - "Config must be a specialization of struct template lower_bound_config"); - using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; using selector = detail::lower_bound_config_selector; @@ -343,8 +340,6 @@ hipError_t upper_bound(void * temporary_storage, hipStream_t stream = 0, bool debug_synchronous = false) { - static_assert(detail::is_default_or_has_tag::value, - "Config must be a specialization of struct template upper_bound_config"); using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; using selector = detail::upper_bound_config_selector; @@ -471,8 +466,6 @@ hipError_t binary_search(void * temporary_storage, hipStream_t stream = 0, bool debug_synchronous = false) { - static_assert(detail::is_default_or_has_tag::value, - "Config must be a specialization of struct template binary_search_config"); using value_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; using selector = detail::binary_search_config_selector; From 629896f44b0ebc13a613f121200fb1bf0b88976b Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Mon, 24 Nov 2025 09:10:26 +0000 Subject: [PATCH 11/26] Resolve "Remove all unused config functions old system" --- .../device/device_segmented_reduce.hpp | 71 ++++++----- .../include/rocprim/device/config_types.hpp | 119 +----------------- .../device/detail/device_nth_element.hpp | 64 +++++----- .../rocprim/device/detail/device_search.hpp | 56 +++++---- .../device/device_histogram_config.hpp | 4 +- .../rocprim/device/device_memcpy_config.hpp | 4 +- .../include/rocprim/device/device_merge.hpp | 2 +- .../rocprim/device/device_merge_sort.hpp | 8 +- .../device/device_merge_sort_config.hpp | 6 +- .../rocprim/device/device_nth_element.hpp | 58 +++++---- .../device/device_nth_element_config.hpp | 47 +++---- .../rocprim/device/device_partition.hpp | 2 +- .../device/device_radix_sort_config.hpp | 6 +- .../rocprim/device/device_search_config.hpp | 46 +++---- .../device_segmented_radix_sort_config.hpp | 4 +- 15 files changed, 183 insertions(+), 314 deletions(-) diff --git a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp index fc2d02fed48..749960012c3 100644 --- a/projects/hipcub/hipcub/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp +++ b/projects/hipcub/hipcub/include/hipcub/backend/rocprim/device/device_segmented_reduce.hpp @@ -51,23 +51,24 @@ namespace detail { template -inline hipError_t launch_segmented_arg_minmax(::rocprim::detail::target_arch arch, - InputIterator input, - OutputIterator output, - OffsetIterator begin_offsets, - OffsetIterator end_offsets, - BinaryFunction reduce_op, - ResultType initial_value, - ResultType empty_value, - dim3 grid, - dim3 block, - size_t shmem, - hipStream_t stream) +inline hipError_t launch_segmented_arg_minmax(::rocprim::detail::target current_target, + InputIterator input, + OutputIterator output, + OffsetIterator begin_offsets, + OffsetIterator end_offsets, + BinaryFunction reduce_op, + ResultType initial_value, + ResultType empty_value, + dim3 grid, + dim3 block, + size_t shmem, + hipStream_t stream) { auto kernel = [=](auto arch_config) { @@ -103,7 +104,12 @@ inline hipError_t launch_segmented_arg_minmax(::rocprim::detail::target_arch arc } }; - return ::rocprim::detail::execute_launch_plan(arch, kernel, grid, block, shmem, stream); + return ::rocprim::detail::execute_launch_plan(current_target, + kernel, + grid, + block, + shmem, + stream); } /// Dispatch function similar to \p rocprim::segmented_reduce but writes \p empty_value for empty @@ -129,17 +135,24 @@ inline hipError_t segmented_arg_minmax(void* temporary_storage, using input_type = typename std::iterator_traits::value_type; using result_type = ::rocprim::accumulator_t; - using config = ::rocprim::detail::wrapped_reduce_config; + using selector = ::rocprim::detail::segmented_reduce_config_selector; ::rocprim::detail::target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); + hipError_t result = ::rocprim::detail::host_target_arch(stream, target_arch); if(result != hipSuccess) { return result; } - const ::rocprim::detail::reduce_config_params params - = ::rocprim::detail::dispatch_target_arch(target_arch); + ::rocprim::detail::gpu target_gpu; + result = ::rocprim::detail::host_target_gpu(stream, target_gpu); + if(result != hipSuccess) + { + return result; + } + + const ::rocprim::detail::target current_target(target_arch, target_gpu); + const auto params = ::rocprim::detail::get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; if(temporary_storage == nullptr) @@ -160,18 +173,18 @@ inline hipError_t segmented_arg_minmax(void* temporary_storage, start = std::chrono::high_resolution_clock::now(); } ROCPRIM_RETURN_ON_ERROR( - launch_segmented_arg_minmax(target_arch, - input, - output, - begin_offsets, - end_offsets, - reduce_op, - static_cast(initial_value), - static_cast(empty_value), - dim3(segments), - dim3(block_size), - 0, - stream)); + launch_segmented_arg_minmax(current_target, + input, + output, + begin_offsets, + end_offsets, + reduce_op, + static_cast(initial_value), + static_cast(empty_value), + dim3(segments), + dim3(block_size), + 0, + stream)); HIPCUB_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("segmented_arg_minmax", segments, start); return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 5c411040575..19ff896c9a1 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -426,42 +426,6 @@ constexpr target_arch device_target_arch() #endif } -template -struct default_config_selector -{ - static constexpr unsigned int block_size - = Config::template architecture_config::params.kernel_config.block_size; -}; - -template -struct target_config -{ - constexpr static auto params = Config::template architecture_config::params; - constexpr static auto wavefront = arch_wavefront_size(Arch); - constexpr static auto arch = Arch; -}; - -// trampoline_kernel that is fully specialized at compile-time for a single GPU architecture. -// By instantiating this template once per supported `target_arch`,the correct tuned config -// will be derived from the template. -template - class LaunchSelector> -ROCPRIM_KERNEL __launch_bounds__((LaunchSelector::block_size)) -void trampoline_kernel(Kernel kernel) -{ - using ArchConfig = target_config; - -#if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 - if constexpr(Arch == device_target_arch()) -#endif - { - kernel(ArchConfig{}); - } -} - template struct launch_plan { @@ -480,30 +444,6 @@ struct launch_plan } }; -template class LaunchSelector = default_config_selector> -auto make_launch_plan(target_arch arch, Kernel kernel) -> launch_plan -{ - std::optional tuned_kernel = std::nullopt; - - for_each_arch( - [&](auto arch_tag) - { - if(arch_tag != arch || tuned_kernel) - return; - - tuned_kernel = trampoline_kernel; - }); - - if(!tuned_kernel) - { - tuned_kernel = trampoline_kernel; - } - - return {tuned_kernel.value(), kernel}; -} - template constexpr target most_common_config(target target_current) { @@ -574,7 +514,7 @@ constexpr typename Selector::param_type get_config(Config config, target t) }; template -struct target_config2 +struct target_config { constexpr static auto params = get_config(Config{}, target{Target{}}); constexpr static auto wavefront = arch_wavefront_size(Target::i); @@ -585,7 +525,7 @@ template struct default_config_static_selector { static constexpr auto block_size - = target_config2::params.kernel_config.block_size; + = target_config::params.kernel_config.block_size; }; // trampoline_kernel that is fully specialized at compile-time for a single GPU architecture. @@ -601,7 +541,7 @@ template::block_size)) void trampoline_kernel(Kernel kernel) { - using ArchConfig = target_config2; + using ArchConfig = target_config; #if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 using Targets = typename Selector::targets; @@ -658,23 +598,6 @@ auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan class LaunchSelector = default_config_selector> -hipError_t execute_launch_plan(target_arch arch, - Kernel kernel, - dim3 grid_size, - dim3 block_size, - size_t shmem, - hipStream_t stream) -{ - const auto launch_plan = make_launch_plan(arch, kernel); - launch_plan.launch(grid_size, block_size, shmem, stream); - return hipGetLastError(); -} - template class LaunchSelector = default_config_static_selector, @@ -689,42 +612,6 @@ hipError_t execute_launch_plan( return hipGetLastError(); } -#ifdef ROCPRIM_EXPERIMENTAL_SPIRV -template -#else -template -#endif -auto dispatch_target_arch([[maybe_unused]] const target_arch target_arch) -{ - if constexpr(!ForceUnknownArch) - { - switch(target_arch) - { - case target_arch::invalid: - assert(false && "Invalid target architecture selected at runtime."); - break; -#define X(ID) case target_arch::ID: return Config::template architecture_config::params - X(unknown); - X(gfx803); - X(gfx900); - X(gfx906); - X(gfx908); - X(gfx90a); - X(gfx942); - X(gfx950); - X(gfx1030); - X(gfx1100); - X(gfx1102); - X(gfx1152); - X(gfx1153); - X(gfx1200); - X(gfx1201); -#undef X - } - } - return Config::template architecture_config::params; -} - inline target_arch parse_gcn_arch(const char* arch_name) { static constexpr auto length = sizeof(hipDeviceProp_t::gcnArchName); diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_nth_element.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_nth_element.hpp index d5a4bf8f3f3..0b6173127e7 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_nth_element.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_nth_element.hpp @@ -542,13 +542,14 @@ ROCPRIM_DEVICE ROCPRIM_FORCE_INLINE void } template ROCPRIM_INLINE hipError_t - nth_element_keys_impl(detail::target_arch target_arch, + nth_element_keys_impl(target current_target, KeysIterator keys, typename std::iterator_traits::value_type* keys_buffer, typename std::iterator_traits::value_type* tree, @@ -620,12 +621,12 @@ hipError_t size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - find_splitters_kernel, - 1, - num_splitters, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + find_splitters_kernel, + 1, + num_splitters, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("find_splitters_kernel", size, start); start_timer(); @@ -638,12 +639,12 @@ hipError_t equality_buckets, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - count_bucket_sizes_kernel, - num_blocks, - num_threads_per_block, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + count_bucket_sizes_kernel, + num_blocks, + num_threads_per_block, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("count_bucket_sizes_kernel", size, start); start_timer(); @@ -654,12 +655,13 @@ hipError_t equality_buckets, rank); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - find_nth_element_bucket_kernel, - 1, - num_buckets, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR( + execute_launch_plan(current_target, + find_nth_element_bucket_kernel, + 1, + num_buckets, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("find_nth_element_bucket_kernel", size, start); start_timer(); @@ -675,12 +677,12 @@ hipError_t compare_function, ordered_bid); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - copy_buckets_kernel, - num_blocks, - num_threads_per_block, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + copy_buckets_kernel, + num_blocks, + num_threads_per_block, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("copy_buckets_kernel", size, start); // Copy the results in keys_buffer back to the keys @@ -720,12 +722,12 @@ hipError_t auto block_sort_kernel = [=](auto arch_config) { block_sort_kernel_impl(keys, size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - block_sort_kernel, - 1, - stop_recursion_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + block_sort_kernel, + 1, + stop_recursion_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("kernel_block_sort", size, start); return hipSuccess; } diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search.hpp index 50c4d435577..89d8ad6478a 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_search.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_search.hpp @@ -290,13 +290,15 @@ hipError_t search_impl(void* temporary_storage, using key_type = typename std::iterator_traits::value_type; using output_type = typename std::iterator_traits::value_type; - using config = wrapped_search_config; + using selector = search_config_selector; target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); - const search_config_params params = dispatch_target_arch(target_arch); - + const auto params = get_config(Config{}, current_target); const unsigned int block_size = params.kernel_config.block_size; const unsigned int items_per_thread = params.kernel_config.items_per_thread; const unsigned int items_per_block = block_size * items_per_thread; @@ -349,12 +351,12 @@ hipError_t search_impl(void* temporary_storage, keys_size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_shared_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_shared_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel_shared", size, start); } else @@ -370,12 +372,12 @@ hipError_t search_impl(void* temporary_storage, keys_size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_shared_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_shared_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel_shared", size, start); } } @@ -393,12 +395,12 @@ hipError_t search_impl(void* temporary_storage, keys_size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel", size, start); } else @@ -414,12 +416,12 @@ hipError_t search_impl(void* temporary_storage, keys_size, compare_function); }; - ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(target_arch, - search_kernel, - num_blocks, - block_size, - 0, - stream)); + ROCPRIM_RETURN_ON_ERROR(execute_launch_plan(current_target, + search_kernel, + num_blocks, + block_size, + 0, + stream)); ROCPRIM_DETAIL_HIP_SYNC_AND_RETURN_ON_ERROR("search_kernel", size, start); } } diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp index c2d5785649f..8dd28c97f52 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_histogram_config.hpp @@ -36,14 +36,14 @@ template struct histogram_config_static_selector { static constexpr auto block_size - = target_config2::params.histogram_config.block_size; + = target_config::params.histogram_config.block_size; }; template struct histogram_global_config_static_selector { static constexpr auto block_size - = target_config2::params.histogram_global_config.block_size; + = target_config::params.histogram_global_config.block_size; }; template diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp index 75a0a1dec5d..6586e1fa197 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_memcpy_config.hpp @@ -43,14 +43,14 @@ namespace detail template struct non_blev_batch_memcpy_config_static_selector { - static constexpr auto block_size = target_config2::params + static constexpr auto block_size = target_config::params .non_blev_batch_memcpy_kernel_config.block_size; }; template struct blev_batch_memcpy_config_static_selector { - static constexpr auto block_size = target_config2::params + static constexpr auto block_size = target_config::params .blev_batch_memcpy_kernel_config.block_size; }; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp index 7765ee05d93..4b41bf68fb3 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge.hpp @@ -54,7 +54,7 @@ inline size_t get_merge_vsmem_size_per_block(detail::target t) { if(target{candidate} == most_common_config(t)) { - using ArchConfig = target_config2; + using ArchConfig = target_config; using merge_kernel_impl_t = merge_kernel_impl_; using merge_vsmem_helper_t = detail::vsmem_helper_impl; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp index 9014faa23e5..73759a712f2 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort.hpp @@ -569,8 +569,8 @@ template; - using BMArchConfig = target_config2; + using BSArchConfig = target_config; + using BMArchConfig = target_config; static constexpr auto bs_params = BSArchConfig::params; static constexpr auto bm_params = BMArchConfig::params; @@ -643,8 +643,8 @@ template inline size_t merge_sort_vsmem_size_for_target(size_t size) { - using BSArchConfig = target_config2; - using BMArchConfig = target_config2; + using BSArchConfig = target_config; + using BMArchConfig = target_config; static constexpr auto bs_params = BSArchConfig::params; static constexpr auto bm_params = BMArchConfig::params; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp index 639e2d111a3..4bf89d0faea 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_merge_sort_config.hpp @@ -90,13 +90,13 @@ template struct merge_oddeven_config_static_selector { static constexpr auto block_size - = target_config2::params.merge_oddeven_config.block_size; + = target_config::params.merge_oddeven_config.block_size; }; template struct merge_mergepath_partition_config_static_selector { - static constexpr auto block_size = target_config2::params + static constexpr auto block_size = target_config::params .merge_mergepath_partition_config.block_size; }; @@ -104,7 +104,7 @@ template struct merge_mergepath_config_static_selector { static constexpr auto block_size - = target_config2::params.merge_mergepath_config.block_size; + = target_config::params.merge_mergepath_config.block_size; }; template diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_nth_element.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_nth_element.hpp index d0bdc770756..25041a9fc6c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_nth_element.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_nth_element.hpp @@ -58,7 +58,7 @@ hipError_t = nullptr) { using key_type = typename std::iterator_traits::value_type; - using config = wrapped_nth_element_config; + using selector = nth_element_config_selector; bool use_atomic_block_id; ROCPRIM_RETURN_ON_ERROR(check_if_using_atomic_block_id(stream, use_atomic_block_id)); @@ -70,17 +70,15 @@ hipError_t [&](auto use_atomic_block_id) { target_arch target_arch; - hipError_t result = host_target_arch(stream, target_arch); - if(result != hipSuccess) - { - return result; - } - const nth_element_config_params params - = dispatch_target_arch(target_arch); - - constexpr unsigned int num_partitions = 3; - const unsigned int num_buckets = params.number_of_buckets; - const unsigned int num_splitters = num_buckets - 1; + ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); + gpu target_gpu; + ROCPRIM_RETURN_ON_ERROR(host_target_gpu(stream, target_gpu)); + const target current_target(target_arch, target_gpu); + + const auto params = get_config(Config{}, current_target); + constexpr unsigned int num_partitions = 3; + const unsigned int num_buckets = params.number_of_buckets; + const unsigned int num_splitters = num_buckets - 1; const unsigned int stop_recursion_size = params.stop_recursion_size; const unsigned int num_items_per_threads = params.kernel_config.items_per_thread; const unsigned int num_threads_per_block = params.kernel_config.block_size; @@ -163,24 +161,24 @@ hipError_t auto ordered_bid = ordered_bid_type::create(ordered_bid_storage); - return nth_element_keys_impl(target_arch, - keys, - keys_buffer, - tree, - nth, - size, - buckets, - equality_buckets, - lookback_states, - num_buckets, - stop_recursion_size, - num_threads_per_block, - num_items_per_threads, - nth_element_data, - compare_function, - stream, - debug_synchronous, - ordered_bid); + return nth_element_keys_impl(current_target, + keys, + keys_buffer, + tree, + nth, + size, + buckets, + equality_buckets, + lookback_states, + num_buckets, + stop_recursion_size, + num_threads_per_block, + num_items_per_threads, + nth_element_data, + compare_function, + stream, + debug_synchronous, + ordered_bid); }, use_atomic_block_id_variant)); return hipSuccess; diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_nth_element_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_nth_element_config.hpp index 3d18bd7e146..c8bb024b7e1 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_nth_element_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_nth_element_config.hpp @@ -35,41 +35,24 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_nth_element_config +template +struct nth_element_config_selector { - template - struct architecture_config - { - static constexpr nth_element_config_params params = NthElementConfig{}; - }; + // Targets can not be fully empty. + using targets + = comp_targets>; + using param_type = nth_element_config_params; + + param_type params; + + template + constexpr nth_element_config_selector(Target) + : params(param_type{ + 64, 64, block_radix_rank_algorithm::match, kernel_config_params{512, 8} + }) + {} }; -// specialized for rocprim::default_config, which instantiates the default_nth_element_config -template -struct wrapped_nth_element_config -{ - template - struct architecture_config - { - static constexpr nth_element_config_params params - = {64, 64, block_radix_rank_algorithm::match, kernel_config<512, 8>()}; - }; -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr nth_element_config_params - wrapped_nth_element_config::architecture_config::params; - -template -template -constexpr nth_element_config_params - wrapped_nth_element_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp index 513a29c4564..d3b384fbc49 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp @@ -72,7 +72,7 @@ inline size_t get_partition_vsmem_size_per_block(detail::target t) { if(target{candidate} == most_common_config(t)) { - using ArchConfig = target_config2; + using ArchConfig = target_config; using partition_kernel_impl_t = partition_kernel_impl_ struct radix_sort_onesweep_histogram_config_static_selector { - static constexpr auto block_size = target_config2::params + static constexpr auto block_size = target_config::params .radix_sort_onesweep_config_params::histogram.block_size; }; template struct radix_sort_onesweep_sort_config_static_selector { - static constexpr auto block_size = target_config2::params + static constexpr auto block_size = target_config::params .radix_sort_onesweep_config_params::sort.block_size; }; @@ -111,7 +111,7 @@ struct radix_sort_block_sort_config_selector template struct radix_sort_block_sort_config_static_selector { - static constexpr auto block_size = target_config2::params.block_size; + static constexpr auto block_size = target_config::params.block_size; }; } // namespace detail diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_search_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_search_config.hpp index c6b23ea42ef..f50704d30eb 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_search_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_search_config.hpp @@ -33,40 +33,24 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -// generic struct that instantiates custom configurations -template -struct wrapped_search_config +template +struct search_config_selector { - template - struct architecture_config - { - static constexpr search_config_params params = Config{}; - }; + // Targets can not be fully empty. + using targets + = comp_targets>; + using param_type = search_config_params; + + param_type params; + + template + constexpr search_config_selector(Target) + : params(param_type{ + 2048, kernel_config_params{256, 4} + }) + {} }; -// specialized for rocprim::default_config, which instantiates the default_search_config -template -struct wrapped_search_config -{ - template - struct architecture_config - { - static constexpr search_config_params params = {2048, kernel_config<256, 4>()}; - }; -}; - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -template -template -constexpr search_config_params - wrapped_search_config::architecture_config::params; - -template -template -constexpr search_config_params - wrapped_search_config::architecture_config::params; -#endif // DOXYGEN_SHOULD_SKIP_THIS - } // namespace detail END_ROCPRIM_NAMESPACE diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp index e267ae58dca..26aac0a06de 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_segmented_radix_sort_config.hpp @@ -75,14 +75,14 @@ template struct segmented_radix_sort_warp_sort_small_config_static_selector { static constexpr auto block_size - = target_config2::params.warp_sort_config.block_size_small; + = target_config::params.warp_sort_config.block_size_small; }; template struct segmented_radix_sort_warp_sort_medium_config_static_selector { static constexpr auto block_size - = target_config2::params.warp_sort_config.block_size_medium; + = target_config::params.warp_sort_config.block_size_medium; }; } // end namespace detail From b3639b60c48fb8f0fe730f85f2ae3d0835789c1f Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Tue, 25 Nov 2025 07:20:09 +0000 Subject: [PATCH 12/26] Resolve "Update autotune create_optimization script for new config system" --- .../benchmark_device_histogram.parallel.hpp | 25 +++ .../device/detail/device_config_helper.hpp | 12 +- .../scripts/autotune/create_optimization.py | 170 ++++++++++++++---- .../adjacent_difference_config_template | 21 ++- ...djacent_difference_inplace_config_template | 21 ++- .../templates/adjacent_find_config_template | 21 ++- .../templates/binary_search_config_template | 21 ++- .../autotune/templates/config_template | 31 +++- .../templates/find_first_of_config_template | 21 ++- .../templates/histogram_config_template | 22 ++- .../templates/lower_bound_config_template | 21 ++- .../autotune/templates/merge_config_template | 26 +-- .../mergesort_block_merge_config_template | 20 ++- .../mergesort_block_sort_config_template | 20 ++- .../templates/partition_flag_config_template | 21 ++- .../partition_predicate_config_template | 21 ++- .../partition_three_way_config_template | 21 ++- .../partition_two_way_flag_config_template | 21 ++- ...artition_two_way_predicate_config_template | 21 ++- .../radixsort_block_sort_config_template | 20 ++- .../radixsort_onesweep_config_template | 25 +-- .../templates/reduce_by_key_config_template | 31 ++-- .../autotune/templates/reduce_config_template | 20 ++- .../run_length_encode_config_template | 31 ++-- ...th_encode_non_trivial_runs_config_template | 30 ++-- .../autotune/templates/scan_config_template | 22 ++- .../templates/scanbykey_config_template | 23 ++- .../templates/search_n_config_template | 20 ++- .../segmented_radix_sort_config_template | 41 +++-- .../segmented_reduce_config_template | 20 ++- .../templates/select_flag_config_template | 21 ++- .../select_predicate_config_template | 21 ++- .../select_predicated_flag_config_template | 21 ++- .../select_unique_by_key_config_template | 21 ++- .../templates/select_unique_config_template | 21 ++- .../templates/transform_config_template | 21 ++- .../transform_pointer_config_template | 21 ++- .../templates/upper_bound_config_template | 21 ++- 38 files changed, 650 insertions(+), 358 deletions(-) diff --git a/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp b/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp index 1a88525565e..e3ff908c714 100644 --- a/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp +++ b/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp @@ -195,6 +195,21 @@ struct device_histogram_benchmark : public benchmark_utils::autotune_interface + ",cfg:" + config_name() + "}"); } + template + void clear_other_caches() + { + ( + [](auto u) + { + using U = decltype(u); + if(!std::is_same_v) + { + input_cache::instance().clear(); + } + }(Args{}), + ...); + } + void run(benchmark_utils::state&& state) override { const auto& stream = state.stream; @@ -303,6 +318,16 @@ struct device_histogram_benchmark : public benchmark_utils::autotune_interface { HIP_CHECK(hipFree(d_histogram[channel])); } + + // Clear caches for other types that are either empty or already done. + clear_other_caches(); } }; diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index fdfb1403ea7..0059f80eabf 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -446,20 +446,20 @@ struct warp_sort_config_params /// \brief Allow the partitioning of batches by size for processing via size-optimized kernels. bool partitioning_allowed = false; /// \brief The number of threads in the logical warp in the small segment processing kernel. - unsigned int logical_warp_size_small = 0; + unsigned int logical_warp_size_small = 1; /// \brief The number of items processed by a thread in the small segment processing kernel. - unsigned int items_per_thread_small = 0; + unsigned int items_per_thread_small = 1; /// \brief The number of threads per block in the small segment processing kernel. - unsigned int block_size_small = 0; + unsigned int block_size_small = 1; /// \brief If the number of segments is at least \p partitioning_threshold, then the segments are partitioned into /// small and large segment groups, and each group is handled by a different, specialized kernel. unsigned int partitioning_threshold = 0; /// \brief The number of threads in the logical warp in the medium segment processing kernel. - unsigned int logical_warp_size_medium = 0; + unsigned int logical_warp_size_medium = 1; /// \brief The number of items processed by a thread in the medium segment processing kernel. - unsigned int items_per_thread_medium = 0; + unsigned int items_per_thread_medium = 1; /// \brief The number of threads per block in the medium segment processing kernel. - unsigned int block_size_medium = 0; + unsigned int block_size_medium = 1; }; struct segmented_radix_sort_config_params diff --git a/projects/rocprim/scripts/autotune/create_optimization.py b/projects/rocprim/scripts/autotune/create_optimization.py index 95b58e6dbb1..6d00a63e74a 100755 --- a/projects/rocprim/scripts/autotune/create_optimization.py +++ b/projects/rocprim/scripts/autotune/create_optimization.py @@ -42,6 +42,36 @@ from jinja2 import Environment, PackageLoader, select_autoescape TARGET_ARCHITECTURES = ['gfx803', 'gfx900', 'gfx906', 'gfx908', 'gfx90a', 'gfx942', 'gfx1030', 'gfx1100', 'gfx1102', 'gfx1201'] +TARGET_GPUS_DICT = { + "MI350X": "mi350x", + "MI325X": "mi325x", + "MI308X": "mi308x", + "MI300A": "mi300a", + "MI300X": "mi300x", + "MI210": "mi210", + "MI100": "mi100", + "RX 9060": "rx9060", + "RX 9070": "rx9070", + "V620": "v620", + "RX 7900": "rx7900", + "RX 6900": "rx6900" +} +@dataclass(frozen=True, eq=True) +class Target: + """ + Data class describing the target of the benchmark + """ + gen: str = "gen::unknown" + arch: str = "target_arch::unknown" + gpu: str = "gpu::generic" + rep: str = "rep::amdgcn" + + def __iter__(self): + yield from (self.gen, self.arch, self.gpu, self.rep) + + def as_str(self) -> str: + return ", ".join(str(v) for v in self) + # C++ typename used for optional types EMPTY_TYPENAME = "empty_type" @@ -135,24 +165,28 @@ def translate_settings_to_cpp_metaprogramming( setting_list.append(f"(!std::is_same<{typename}, rocprim::{EMPTY_TYPENAME}>::value)") for name, value in const_configuration.items(): setting_list.append(f"({name} == {value})") - return "std::enable_if_t<(" + " && ".join(setting_list) + ")>" + return "if constexpr(" + " && ".join(setting_list) + ")" -class BenchmarksOfArchitecture: +class BenchmarksOfTarget: """ Stores the benchmark results for a specific architecture and algorithm. """ - def __init__(self, arch_name: str, config_selection_params, fallback_entries: List[FallbackCase], config_get_best, algorithm_name): - self.config_selection_params = config_selection_params - self.fallback_entries: List[FallbackCase] = fallback_entries - self.arch_name: str = arch_name - self.config_get_best: Callable[[Dict], Dict[str, str]] = config_get_best - self.algorithm_name: str = algorithm_name - # Dictionary storing the benchmarks - # Key is an instantiation of the configuration selection types - # Value is a list of all benchmark runs corresponding to that instantiation, - # these benchmarks in this list vary in the actual configuration used to run the benchmark - self.benchmarks = defaultdict(list) + def __init__( + self, + target: Target, + config_selection_params=None, + fallback_entries: Optional[List["FallbackCase"]] = None, + config_get_best: Optional[Callable[[Dict], Dict[str, str]]] = None, + algorithm_name: Optional[str] = None, + ): + self.__target: Target = target + self.config_selection_params = config_selection_params + self.fallback_entries: List["FallbackCase"] = fallback_entries or [] + self.config_get_best = config_get_best + self.algorithm_name = algorithm_name + self.algorithm_name_short = algorithm_name.replace("device_", "") if algorithm_name != None else None + self.benchmarks = defaultdict(list) def __get_instance_key(self, instanced_types): """ @@ -177,8 +211,11 @@ def add_measurement(self, benchmark_data: Dict[str, str]): self.benchmarks[instance_key].append(benchmark_data) @property - def name(self) -> str: - return self.arch_name + def target(self) -> Target: + return self.__target + + def get_enable_if(self, config_params) -> str: + return f'std::enable_if_t>::value, {config_params}>' def __get_best_benchmark(self, instance_key) -> Dict[str, str]: """ @@ -348,31 +385,38 @@ class Algorithm: """ def __init__(self, fallback_entries: List[FallbackCase], config_get_best = default_config_get_best): - self.architectures: Dict[str, BenchmarksOfArchitecture] = {} + self.targets: Dict[Target, BenchmarksOfTarget] = {} self.fallback_entries: List[FallbackCase] = fallback_entries self.config_get_best = config_get_best - def add_measurement(self, single_benchmark_data: Dict[str, str], architecture: str): + def __get_fallback_target(self) -> BenchmarksOfTarget: + """ + Returns M100 or first gpu in the list + """ + for target in self.targets: + if target.gpu == "mi100": + return self.targets[target] + return next(iter(self.targets.values())) + + + def add_measurement(self, single_benchmark_data: Dict[str, str], target: Target): """ Adds a single benchmark execution for a given architecture """ - if architecture not in self.architectures: - self.architectures[architecture] = BenchmarksOfArchitecture(architecture, self.config_selection_params, + if target not in self.targets: + self.targets[target] = BenchmarksOfTarget(target, self.config_selection_params, self.fallback_entries, self.config_get_best, self.algorithm_name) - self.architectures[architecture].add_measurement(single_benchmark_data) + self.targets[target].add_measurement(single_benchmark_data) def create_config_file_content(self) -> str: """ Generate the content of the configuration file, including license and header guards, based on general template file. """ - if 'target_arch::gfx908' in self.architectures: - self.architectures['target_arch::unknown'] = copy.deepcopy(self.architectures['target_arch::gfx908']) - self.architectures['target_arch::unknown'].arch_name = 'target_arch::unknown' - algorithm_template = env.get_template(self.cpp_configuration_template_name) - rendered_template = algorithm_template.render(all_architectures=self.architectures.values()) + fallback_target = self.__get_fallback_target() + rendered_template = algorithm_template.render(all_targets=self.targets.values(), fallback_target=fallback_target, unknown_target= BenchmarksOfTarget(Target())) return rendered_template @@ -808,11 +852,70 @@ def __get_target_architecture_from_context(self, benchmark_run): Uses the benchmark run context embedded into the benchmark output json to retrieve the targeted architecture """ - name_from_context = benchmark_run['context']['hdp_gcn_arch_name'].split(":")[0] - if name_from_context in TARGET_ARCHITECTURES: - return f'target_arch::{name_from_context}' + arch_from_context = benchmark_run['context']['hdp_gcn_arch_name'].split(":")[0] + if arch_from_context in TARGET_ARCHITECTURES: + return f'target_arch::{arch_from_context}' else: - raise RuntimeError(f"ERROR: unknown hdp_gcn_arch_name: {name_from_context}") + raise RuntimeError(f"ERROR: unknown hdp_gcn_arch_name: {arch_from_context}") + + def __get_target_gpu_from_context(self, benchmark_run): + """ + Uses the benchmark run context embedded into the benchmark output json to retrieve the targeted gpu + """ + + gpu_from_context = benchmark_run['context']['hdp_name'] + + ret = "gpu::generic" + + for gpu in TARGET_GPUS_DICT.keys(): + if gpu in gpu_from_context: + ret = f'gpu::{TARGET_GPUS_DICT[gpu]}' + + if ret == "gpu::generic": + print("WARNING: Could find gpu in defined gpus will use gpu::generic", file=sys.stderr, flush=True) + return ret + + def __get_target_rep_from_context(self, benchmark_run): + """ + Uses the benchmark run context embedded into the benchmark output json to retrieve the targeted rep + TODO The data is not yet inbedded in the benchmark output json + """ + return "rep::amdgcn" + + def __get_gen_from_architecture(self, arch): + match arch: + case "target_arch::gfx803": + return "gen::gcn3" + case "target_arch::gfx900" | "target_arch::gfx906": + return "gen::gcn5" + case "target_arch::gfx908": + return "gen::cdna1" + case "target_arch::gfx90a": + return "gen::cdna2" + case "target_arch::gfx942": + return "gen::cdna3" + case "target_arch::gfx950": + return "gen::cdna4" + case "target_arch::gfx1030": + return "gen::rdna2" + case "target_arch::gfx1100" | "target_arch::gfx1102": + return "gen::rdna3" + case "target_arch::gfx1200" | "target_arch::gfx1201": + return "gen::rdna4" + case _: + return "gen::unknown" + + def __get_target(self, benchmark_run) -> Target: + """ + Uses the benchmark run context embedded into the benchmark output json to retrieve the target + """ + arch = self.__get_target_architecture_from_context(benchmark_run) + gpu = self.__get_target_gpu_from_context(benchmark_run) + rep = self.__get_target_rep_from_context(benchmark_run) + gen = self.__get_gen_from_architecture(arch) + + return Target(gen, arch, gpu, rep) + def __get_single_benchmark(self, single_benchmark): """ @@ -828,7 +931,7 @@ def __get_single_benchmark(self, single_benchmark): raise RuntimeError(f"ERROR: cannot parse JSON from: \"{single_benchmark['name']}\"") return dict(single_benchmark, **tokenized_name) - def __add_benchmark_to_algorithm(self, single_benchmark, arch): + def __add_benchmark_to_algorithm(self, single_benchmark, target): """ Adds a single_benchmark execution of a given Algorithm on a given architecture, to the Algorithm object @@ -839,7 +942,7 @@ def __add_benchmark_to_algorithm(self, single_benchmark, arch): algorithm_name += "_" + single_benchmark['subalgo'] if algorithm_name not in self.algorithms: self.algorithms[algorithm_name] = create_algorithm(algorithm_name, self.fallback_entries) - self.algorithms[algorithm_name].add_measurement(single_benchmark, arch) + self.algorithms[algorithm_name].add_measurement(single_benchmark, target) def add_run(self, benchmark_run_file_path: str): """ @@ -852,10 +955,11 @@ def add_run(self, benchmark_run_file_path: str): try: print(f'INFO: Processing "{benchmark_run_file_path}"') - arch = self.__get_target_architecture_from_context(benchmark_run_data) + target = self.__get_target(benchmark_run_data) + for raw_single_benchmark in benchmark_run_data['benchmarks']: single_benchmark = self.__get_single_benchmark(raw_single_benchmark) - self.__add_benchmark_to_algorithm(single_benchmark, arch) + self.__add_benchmark_to_algorithm(single_benchmark, target) print(f'INFO: Successfully processed file "{benchmark_run_file_path}"') except NotSupportedError as error: print(f'WARNING: Could not process file "{benchmark_run_file_path}": {error}', file=sys.stderr, flush=True) diff --git a/projects/rocprim/scripts/autotune/templates/adjacent_difference_config_template b/projects/rocprim/scripts/autotune/templates/adjacent_difference_config_template index 353e86d1fae..02c62d1d546 100644 --- a/projects/rocprim/scripts/autotune/templates/adjacent_difference_config_template +++ b/projects/rocprim/scripts/autotune/templates/adjacent_difference_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_DIFFERENCE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -adjacent_difference_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +adjacent_difference_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_adjacent_difference_config : default_adjacent_difference_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto adjacent_difference_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_adjacent_difference_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("adjacent_difference_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return adjacent_difference_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return adjacent_difference_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/adjacent_difference_inplace_config_template b/projects/rocprim/scripts/autotune/templates/adjacent_difference_inplace_config_template index 3e5d7fa5cdb..6cfb8a99aef 100644 --- a/projects/rocprim/scripts/autotune/templates/adjacent_difference_inplace_config_template +++ b/projects/rocprim/scripts/autotune/templates/adjacent_difference_inplace_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_DIFFERENCE_INPLACE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -adjacent_difference_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +adjacent_difference_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_adjacent_difference_inplace_config : default_adjacent_difference_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto adjacent_difference_inplace_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_adjacent_difference_inplace_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("adjacent_difference_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return adjacent_difference_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return adjacent_difference_inplace_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/adjacent_find_config_template b/projects/rocprim/scripts/autotune/templates/adjacent_find_config_template index 8ef25c83eb1..9bd7e9e1992 100644 --- a/projects/rocprim/scripts/autotune/templates/adjacent_find_config_template +++ b/projects/rocprim/scripts/autotune/templates/adjacent_find_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_ADJACENT_FIND_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -adjacent_find_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +adjacent_find_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_adjacent_find_config : default_adjacent_find_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto adjacent_find_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_adjacent_find_config({{ benchmark_of_architecture.name }}), input_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("adjacent_find_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return adjacent_find_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return adjacent_find_config_picker, input_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/binary_search_config_template b/projects/rocprim/scripts/autotune/templates/binary_search_config_template index 00b19ca8380..d3a00c773ff 100644 --- a/projects/rocprim/scripts/autotune/templates/binary_search_config_template +++ b/projects/rocprim/scripts/autotune/templates/binary_search_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_BINARY_SEARCH_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -binary_search_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_binary_search_config : default_binary_search_config_base -{}; +{% macro config_picker() -%} +template constexpr auto binary_search_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_binary_search_config({{ benchmark_of_architecture.name }}), value_type, output_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return binary_search_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return binary_search_config_picker, value_type, output_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/config_template b/projects/rocprim/scripts/autotune/templates/config_template index 9d44e5cdcbd..1b8112c749b 100644 --- a/projects/rocprim/scripts/autotune/templates/config_template +++ b/projects/rocprim/scripts/autotune/templates/config_template @@ -41,15 +41,34 @@ BEGIN_ROCPRIM_NAMESPACE namespace detail { -{{ general_case() }} - -{% for benchmark_of_architecture in all_architectures %} - {% for based_on_type, fallback_selection_criteria, measurement in benchmark_of_architecture.fallback_types %} -{{ configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) }} -{{ kernel_configuration(measurement) }} +{% set ns = namespace(algorithm_name='') %} +{% for benchmark_of_target in all_targets %} +{% set ns.algorithm_name = benchmark_of_target.algorithm_name_short %} +{{ config_picker() }} -> {{ enable_if(benchmark_of_target) }} +{ + {% for based_on_type, fallback_selection_criteria, measurement in benchmark_of_target.fallback_types %} +// Based on {{ based_on_type }} +{{ fallback_selection_criteria }} +{ return {{ kernel_configuration(measurement) }}; } {% endfor %} +// Default case if none of the conditions match +{{ default_case() }} +} + +{% endfor %} + +{{ config_picker() }} -> {{ enable_if(unknown_target) }} +{ +{{fallback_config(fallback_target)}} +} + +// All existing configs +using {{ns.algorithm_name}}_targets = comp_targets< +{% for benchmark_of_target in all_targets %} +comp_target<{{benchmark_of_target.target.as_str()}}>, {% endfor %} +comp_target<{{unknown_target.target.as_str()}}>>; } // end namespace detail diff --git a/projects/rocprim/scripts/autotune/templates/find_first_of_config_template b/projects/rocprim/scripts/autotune/templates/find_first_of_config_template index 84e7fb476aa..9f3aa58b36d 100644 --- a/projects/rocprim/scripts/autotune/templates/find_first_of_config_template +++ b/projects/rocprim/scripts/autotune/templates/find_first_of_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_FIND_FIRST_OF_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -find_first_of_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +find_first_of_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_find_first_of_config : default_find_first_of_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto find_first_of_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_find_first_of_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("find_first_of_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return find_first_of_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return find_first_of_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/histogram_config_template b/projects/rocprim/scripts/autotune/templates/histogram_config_template index 810d2b84fa6..0f82adb7dcc 100644 --- a/projects/rocprim/scripts/autotune/templates/histogram_config_template +++ b/projects/rocprim/scripts/autotune/templates/histogram_config_template @@ -5,17 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_HISTOGRAM_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -histogram_config, {{ measurement['cfg']['max_grid_size'] }}, {{ measurement['cfg']['shared_impl_max_bins'] }}, {{ measurement['cfg']['shared_impl_histograms'] }}, kernel_config<{{ measurement['cfg']['global_hist_bs'] }}, {{ measurement['cfg']['global_hist_ipt'] }}>> { }; +histogram_config_params{kernel_config_params{ {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, {{ measurement['cfg']['max_grid_size'] }}, {{ measurement['cfg']['shared_impl_max_bins'] }}, {{ measurement['cfg']['shared_impl_histograms'] }}, kernel_config_params{ {{ measurement['cfg']['global_hist_bs'] }}, {{ measurement['cfg']['global_hist_ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_histogram_config : -default_histogram_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto histogram_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_histogram_config({{ benchmark_of_architecture.name }}), value_type, channels, active_channels, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("histogram_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return histogram_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return histogram_config_picker, value_type, channels, active_channels>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/lower_bound_config_template b/projects/rocprim/scripts/autotune/templates/lower_bound_config_template index f2d49ff118a..445fbfbcc9b 100644 --- a/projects/rocprim/scripts/autotune/templates/lower_bound_config_template +++ b/projects/rocprim/scripts/autotune/templates/lower_bound_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_LOWER_BOUND_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -lower_bound_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_lower_bound_config : default_binary_search_config_base -{}; +{% macro config_picker() -%} +template constexpr auto lower_bound_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_lower_bound_config({{ benchmark_of_architecture.name }}), value_type, output_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return binary_search_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return lower_bound_config_picker, value_type, output_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/merge_config_template b/projects/rocprim/scripts/autotune/templates/merge_config_template index e89592f68ed..7234cca5260 100644 --- a/projects/rocprim/scripts/autotune/templates/merge_config_template +++ b/projects/rocprim/scripts/autotune/templates/merge_config_template @@ -5,21 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -merge_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +merge_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_merge_config : default_merge_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto merge_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_merge_config< - static_cast({{ benchmark_of_architecture.name }}), - key_type, - value_type, - {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("merge_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return merge_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return merge_config_picker, key_type, value_type>(); {%- endmacro %} \ No newline at end of file diff --git a/projects/rocprim/scripts/autotune/templates/mergesort_block_merge_config_template b/projects/rocprim/scripts/autotune/templates/mergesort_block_merge_config_template index 06e4236562f..f0c39409606 100644 --- a/projects/rocprim/scripts/autotune/templates/mergesort_block_merge_config_template +++ b/projects/rocprim/scripts/autotune/templates/mergesort_block_merge_config_template @@ -5,15 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_SORT_BLOCK_MERGE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -merge_sort_block_merge_config<256, 1, (1 << 17) + 70000, {{ measurement['cfg']['mergepath_partition_bs'] }}, {{ measurement['cfg']['mergepath_bs'] }}, {{ measurement['cfg']['mergepath_ipt'] }}> { }; +merge_sort_block_merge_config_params{ { 256, 1, (1 << 17) + 70000 }, { {{ measurement['cfg']['mergepath_partition_bs'] }}, 1}, { {{ measurement['cfg']['mergepath_bs'] }}, {{ measurement['cfg']['mergepath_ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template struct default_merge_sort_block_merge_config : -merge_sort_block_merge_config_base::type {}; +{% macro config_picker() -%} +template constexpr auto merge_sort_block_merge_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_merge_sort_block_merge_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("merge_sort_block_merge_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return merge_sort_block_merge_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return merge_sort_block_merge_config_picker, key_type, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/mergesort_block_sort_config_template b/projects/rocprim/scripts/autotune/templates/mergesort_block_sort_config_template index f3931f2c780..84666b0cf5a 100644 --- a/projects/rocprim/scripts/autotune/templates/mergesort_block_sort_config_template +++ b/projects/rocprim/scripts/autotune/templates/mergesort_block_sort_config_template @@ -5,19 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_MERGE_SORT_BLOCK_SORT_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -merge_sort_block_sort_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +merge_sort_block_sort_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template struct default_merge_sort_block_sort_config : -merge_sort_block_sort_config_base::type {}; +{% macro config_picker() -%} +template constexpr auto merge_sort_block_sort_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_merge_sort_block_sort_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("merge_sort_block_sort_config_params") }} {%- endmacro %} +{% macro default_case() -%} +return merge_sort_block_sort_config_params_base(); +{%- endmacro %} - - +{% macro fallback_config(fallback_target) -%} + return merge_sort_block_sort_config_picker, key_type, value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_flag_config_template b/projects/rocprim/scripts/autotune/templates/partition_flag_config_template index 4495e5de7a8..e67f0541f04 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_flag_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_flag_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_FLAG_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_flag_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_flag_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_flag_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_flag_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_predicate_config_template b/projects/rocprim/scripts/autotune/templates/partition_predicate_config_template index 021a343ae3d..95892f528a8 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_predicate_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_predicate_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_PREDICATE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_predicate_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_predicate_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_predicate_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_predicate_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_three_way_config_template b/projects/rocprim/scripts/autotune/templates/partition_three_way_config_template index 01f06bc8b1f..7929fdfb30c 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_three_way_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_three_way_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_THREE_WAY_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_three_way_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_three_way_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_three_way_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_three_way_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_two_way_flag_config_template b/projects/rocprim/scripts/autotune/templates/partition_two_way_flag_config_template index 16826143e61..0eca0b9da8e 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_two_way_flag_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_two_way_flag_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_TWO_WAY_FLAG_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_two_way_flag_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_two_way_flag_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_two_way_flag_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_two_way_flag_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/partition_two_way_predicate_config_template b/projects/rocprim/scripts/autotune/templates/partition_two_way_predicate_config_template index ac11f1459de..a4c9a927745 100644 --- a/projects/rocprim/scripts/autotune/templates/partition_two_way_predicate_config_template +++ b/projects/rocprim/scripts/autotune/templates/partition_two_way_predicate_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_PARTITION_TWO_WAY_PREDICATE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_partition_two_way_predicate_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto partition_two_way_predicate_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_partition_two_way_predicate_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return partition_two_way_predicate_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/radixsort_block_sort_config_template b/projects/rocprim/scripts/autotune/templates/radixsort_block_sort_config_template index b27a6aa30d8..5f664fee5f7 100644 --- a/projects/rocprim/scripts/autotune/templates/radixsort_block_sort_config_template +++ b/projects/rocprim/scripts/autotune/templates/radixsort_block_sort_config_template @@ -5,19 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RADIX_SORT_BLOCK_SORT_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -kernel_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +kernel_config_params{ {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_radix_sort_block_sort_config : -radix_sort_block_sort_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto radix_sort_block_sort_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_radix_sort_block_sort_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("kernel_config_params") }} {%- endmacro %} - - +{% macro default_case() -%} +return radix_sort_block_sort_config_params_base(); +{%- endmacro %} +{% macro fallback_config(fallback_target) -%} + return radix_sort_block_sort_config_picker, key_type, value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/radixsort_onesweep_config_template b/projects/rocprim/scripts/autotune/templates/radixsort_onesweep_config_template index 591df8ef88b..3642304fce9 100644 --- a/projects/rocprim/scripts/autotune/templates/radixsort_onesweep_config_template +++ b/projects/rocprim/scripts/autotune/templates/radixsort_onesweep_config_template @@ -5,20 +5,25 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RADIX_SORT_ONESWEEP_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -radix_sort_onesweep_config, -kernel_config<{{ measurement['cfg']['sort']['bs'] }}, {{ measurement['cfg']['sort']['ipt'] }}>, {{ measurement['cfg']['bits_per_place'] }}, {{ measurement['cfg']['algorithm'] }}> { }; +radix_sort_onesweep_config_params{kernel_config_params{ {{ measurement['cfg']['histogram']['bs'] }}, {{ measurement['cfg']['histogram']['ipt'] }} }, +kernel_config_params{ {{ measurement['cfg']['sort']['bs'] }}, {{ measurement['cfg']['sort']['ipt'] }} }, {{ measurement['cfg']['bits_per_place'] }}, {{ measurement['cfg']['algorithm'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_radix_sort_onesweep_config : -radix_sort_onesweep_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto radix_sort_onesweep_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_radix_sort_onesweep_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("radix_sort_onesweep_config_params") }} {%- endmacro %} +{% macro default_case() -%} +return radix_sort_onesweep_config_params_base(); +{%- endmacro %} - - +{% macro fallback_config(fallback_target) -%} + return radix_sort_onesweep_config_picker< + comp_target<{{ fallback_target.target.as_str() }}>, + key_type, + value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/reduce_by_key_config_template b/projects/rocprim/scripts/autotune/templates/reduce_by_key_config_template index e13ea2e5990..55fd225410f 100644 --- a/projects/rocprim/scripts/autotune/templates/reduce_by_key_config_template +++ b/projects/rocprim/scripts/autotune/templates/reduce_by_key_config_template @@ -5,26 +5,25 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_REDUCE_BY_KEY_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -reduce_by_key_config< - {{ measurement['cfg']['bs'] }}, - {{ measurement['cfg']['ipt'] }}, +reduce_by_key_config_params{ + { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, block_load_method::block_load_transpose, block_load_method::block_load_transpose, - block_scan_algorithm::using_warp_scan> { }; + block_scan_algorithm::using_warp_scan } {%- endmacro %} -{% macro general_case() -%} -template -struct default_reduce_by_key_config : default_reduce_by_key_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto reduce_by_key_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_reduce_by_key_config< - static_cast({{ benchmark_of_architecture.name }}), - key_type, - value_type, - {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("reduce_by_key_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return reduce_by_key_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return reduce_by_key_config_picker, key_type, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/reduce_config_template b/projects/rocprim/scripts/autotune/templates/reduce_config_template index 3d3924cf39e..c28969387c5 100644 --- a/projects/rocprim/scripts/autotune/templates/reduce_config_template +++ b/projects/rocprim/scripts/autotune/templates/reduce_config_template @@ -5,19 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_REDUCE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -reduce_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::rocprim::block_reduce_algorithm::{{ measurement['cfg']['method'] }}> { }; +reduce_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::block_reduce_algorithm::{{ measurement['cfg']['method'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_reduce_config : -default_reduce_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto reduce_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_reduce_config({{ benchmark_of_architecture.name }}), key_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("reduce_config_params") }} {%- endmacro %} - - +{% macro default_case() -%} +return reduce_config_params_base(); +{%- endmacro %} +{% macro fallback_config(fallback_target) -%} + return reduce_config_picker, key_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/run_length_encode_config_template b/projects/rocprim/scripts/autotune/templates/run_length_encode_config_template index 17b2003d7da..e443a1f8662 100644 --- a/projects/rocprim/scripts/autotune/templates/run_length_encode_config_template +++ b/projects/rocprim/scripts/autotune/templates/run_length_encode_config_template @@ -5,26 +5,25 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RUN_LENGTH_ENCODE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -reduce_by_key_config< - {{ measurement['cfg']['bs'] }}, - {{ measurement['cfg']['ipt'] }}, +reduce_by_key_config_params{ + { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, block_load_method::block_load_transpose, block_load_method::block_load_direct, - block_scan_algorithm::using_warp_scan> { }; + block_scan_algorithm::using_warp_scan} {%- endmacro %} -{% macro general_case() -%} -template -struct default_trivial_runs_config : default_reduce_by_key_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto run_length_encode_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_trivial_runs_config< - static_cast({{ benchmark_of_architecture.name }}), - key_type, - value_type, - {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("reduce_by_key_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return reduce_by_key_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return run_length_encode_config_picker, key_type, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/run_length_encode_non_trivial_runs_config_template b/projects/rocprim/scripts/autotune/templates/run_length_encode_non_trivial_runs_config_template index 49ce27a781c..e82a0225496 100644 --- a/projects/rocprim/scripts/autotune/templates/run_length_encode_non_trivial_runs_config_template +++ b/projects/rocprim/scripts/autotune/templates/run_length_encode_non_trivial_runs_config_template @@ -5,24 +5,24 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_RUN_LENGTH_ENCODE_NON_TRIVIAL_RUNS_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -non_trivial_runs_config< - {{ measurement['cfg']['bs'] }}, - {{ measurement['cfg']['ipt'] }}, +non_trivial_runs_config_params{ + { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::{{ measurement['cfg']['load_method'] }}, - ::rocprim::block_scan_algorithm::using_warp_scan> { }; + ::rocprim::block_scan_algorithm::using_warp_scan} {%- endmacro %} -{% macro general_case() -%} -template -struct default_non_trivial_runs_config : default_non_trivial_runs_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto run_length_encode_non_trivial_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template -struct default_non_trivial_runs_config< - static_cast({{ benchmark_of_architecture.name }}), - key_type, - {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("non_trivial_runs_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return non_trivial_runs_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return run_length_encode_non_trivial_config_picker, key_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/scan_config_template b/projects/rocprim/scripts/autotune/templates/scan_config_template index 6de5a83ad95..3602f395b50 100644 --- a/projects/rocprim/scripts/autotune/templates/scan_config_template +++ b/projects/rocprim/scripts/autotune/templates/scan_config_template @@ -5,19 +5,23 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SCAN_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -scan_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, {{ measurement['cfg']['method'] }}> { }; +scan_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, {{ measurement['cfg']['method'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_scan_config : -default_scan_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto scan_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_scan_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("scan_config_params") }} {%- endmacro %} - - +{% macro default_case() -%} +return scan_config_params_base(); +{%- endmacro %} +{% macro fallback_config(fallback_target) -%} + return scan_config_picker< + comp_target<{{ fallback_target.target.as_str() }}>, + value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/scanbykey_config_template b/projects/rocprim/scripts/autotune/templates/scanbykey_config_template index 72653bb61a4..52138a1b966 100644 --- a/projects/rocprim/scripts/autotune/templates/scanbykey_config_template +++ b/projects/rocprim/scripts/autotune/templates/scanbykey_config_template @@ -5,19 +5,24 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SCAN_BY_KEY_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -scan_by_key_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, {{ measurement['cfg']['method'] }}> { }; +scan_by_key_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::block_load_method::block_load_transpose, ::rocprim::block_store_method::block_store_transpose, {{ measurement['cfg']['method'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_scan_by_key_config : -default_scan_by_key_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto scan_by_key_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_scan_by_key_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("scan_by_key_config_params") }} {%- endmacro %} - - +{% macro default_case() -%} +return scan_by_key_config_params_base(); +{%- endmacro %} +{% macro fallback_config(fallback_target) -%} + return scan_by_key_config_picker< + comp_target<{{ fallback_target.target.as_str() }}>, + key_type, + value_type>(); +{%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/search_n_config_template b/projects/rocprim/scripts/autotune/templates/search_n_config_template index 7ad5393b1e7..919b5c29b03 100644 --- a/projects/rocprim/scripts/autotune/templates/search_n_config_template +++ b/projects/rocprim/scripts/autotune/templates/search_n_config_template @@ -5,15 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEARCH_N_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -search_n_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, {{ measurement['cfg']['threshold'] }}> { }; +search_n_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, {{ measurement['cfg']['threshold'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_search_n_config : -default_search_n_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto search_n_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_search_n_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("search_n_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return search_n_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return search_n_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/segmented_radix_sort_config_template b/projects/rocprim/scripts/autotune/templates/segmented_radix_sort_config_template index 437718b9afd..8e625904636 100644 --- a/projects/rocprim/scripts/autotune/templates/segmented_radix_sort_config_template +++ b/projects/rocprim/scripts/autotune/templates/segmented_radix_sort_config_template @@ -5,28 +5,37 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_RADIX_SORT_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -segmented_radix_sort_config< +segmented_radix_sort_config_params{ {{ measurement['cfg']['rb'] }}, - kernel_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}>, - typename std::conditional< - {{ measurement['cfg']['wsc']['pa'] }}, - WarpSortConfig< + kernel_config_params{ {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, +{% if measurement['cfg']['wsc']['pa'] == 1 -%} + warp_sort_config_params{ + {{ measurement['cfg']['wsc']['pa'] }}, {{ measurement['cfg']['wsc']['lwss'] }}, {{ measurement['cfg']['wsc']['ipts'] }}, {{ measurement['cfg']['wsc']['bss'] }}, {{ measurement['cfg']['wsc']['pt'] }}, {{ measurement['cfg']['wsc']['lwsm'] }}, {{ measurement['cfg']['wsc']['iptm'] }}, - {{ measurement['cfg']['wsc']['bsm'] }}>, - DisabledWarpSortConfig - >::type, - {{ measurement['cfg']['eupws'] }} > { }; + {{ measurement['cfg']['wsc']['bsm'] }} }, +{% else -%} + warp_sort_config_params{0}, +{% endif -%} + {{ measurement['cfg']['eupws'] }} } {%- endmacro %} -{% macro general_case() -%} -template -struct default_segmented_radix_sort_config : default_segmented_radix_sort_config_base<6>::type -{}; +{% macro config_picker() -%} +template constexpr auto segmented_radix_sort_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_segmented_radix_sort_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("segmented_radix_sort_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return segmented_radix_sort_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return segmented_radix_sort_config_picker< + comp_target<{{ fallback_target.target.as_str() }}>, + key_type, + value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/segmented_reduce_config_template b/projects/rocprim/scripts/autotune/templates/segmented_reduce_config_template index 3b6ed10bcf3..acce73162a5 100644 --- a/projects/rocprim/scripts/autotune/templates/segmented_reduce_config_template +++ b/projects/rocprim/scripts/autotune/templates/segmented_reduce_config_template @@ -5,15 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SEGMENTED_REDUCE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -reduce_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ::rocprim::block_reduce_algorithm::{{ measurement['cfg']['method'] }}> { }; +reduce_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::block_reduce_algorithm::{{ measurement['cfg']['method'] }} } {%- endmacro %} -{% macro general_case() -%} -template struct default_segmented_reduce_config : -default_reduce_config_base::type { }; +{% macro config_picker() -%} +template constexpr auto segmented_reduce_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_segmented_reduce_config({{ benchmark_of_architecture.name }}), key_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("reduce_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return reduce_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return segmented_reduce_config_picker, key_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_flag_config_template b/projects/rocprim/scripts/autotune/templates/select_flag_config_template index 36a7055e485..80f8a8aba18 100644 --- a/projects/rocprim/scripts/autotune/templates/select_flag_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_flag_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_FLAG_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_flag_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_flag_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_flag_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_flag_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_predicate_config_template b/projects/rocprim/scripts/autotune/templates/select_predicate_config_template index 137ba5ee387..b5450975e32 100644 --- a/projects/rocprim/scripts/autotune/templates/select_predicate_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_predicate_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_predicate_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_predicate_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_predicate_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_predicate_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_predicated_flag_config_template b/projects/rocprim/scripts/autotune/templates/select_predicated_flag_config_template index 2ba3565aefd..7907d92b2df 100644 --- a/projects/rocprim/scripts/autotune/templates/select_predicated_flag_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_predicated_flag_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_PREDICATED_FLAG_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_predicated_flag_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_predicated_flag_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_predicated_flag_config({{ benchmark_of_architecture.name }}), data_type, flag_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_predicated_flag_config_picker, data_type, flag_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_unique_by_key_config_template b/projects/rocprim/scripts/autotune/templates/select_unique_by_key_config_template index 2177f95af1d..200b2de7b2f 100644 --- a/projects/rocprim/scripts/autotune/templates/select_unique_by_key_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_unique_by_key_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_BY_KEY_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_unique_by_key_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_unique_by_key_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_unique_by_key_config({{ benchmark_of_architecture.name }}), key_type, value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_unique_by_key_config_picker, key_type, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/select_unique_config_template b/projects/rocprim/scripts/autotune/templates/select_unique_config_template index 9bb8aec9f34..9c005946dca 100644 --- a/projects/rocprim/scripts/autotune/templates/select_unique_config_template +++ b/projects/rocprim/scripts/autotune/templates/select_unique_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_SELECT_UNIQUE_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -select_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +partition_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_select_unique_config : default_partition_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto select_unique_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_select_unique_config({{ benchmark_of_architecture.name }}), data_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("partition_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return partition_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return select_unique_config_picker, data_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/transform_config_template b/projects/rocprim/scripts/autotune/templates/transform_config_template index bc9e4fe8107..fbbef6af1b5 100644 --- a/projects/rocprim/scripts/autotune/templates/transform_config_template +++ b/projects/rocprim/scripts/autotune/templates/transform_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -transform_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_transform_config : default_transform_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto transform_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_transform_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return transform_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return transform_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/transform_pointer_config_template b/projects/rocprim/scripts/autotune/templates/transform_pointer_config_template index efa1193dff1..46ad6667241 100644 --- a/projects/rocprim/scripts/autotune/templates/transform_pointer_config_template +++ b/projects/rocprim/scripts/autotune/templates/transform_pointer_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_TRANSFORM_POINTER_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -transform_pointer_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}, ROCPRIM_GRID_SIZE_LIMIT, ::rocprim::{{ measurement['cfg']['lt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} }, ::rocprim::{{ measurement['cfg']['lt'] }} } {%- endmacro %} -{% macro general_case() -%} -template -struct default_transform_pointer_config : default_transform_pointer_config_base::type -{}; +{% macro config_picker() -%} +template constexpr auto transform_pointer_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_transform_pointer_config({{ benchmark_of_architecture.name }}), value_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return transform_pointer_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return transform_pointer_config_picker, value_type>(); {%- endmacro %} diff --git a/projects/rocprim/scripts/autotune/templates/upper_bound_config_template b/projects/rocprim/scripts/autotune/templates/upper_bound_config_template index 05d05c4de98..9bd9c8a83c3 100644 --- a/projects/rocprim/scripts/autotune/templates/upper_bound_config_template +++ b/projects/rocprim/scripts/autotune/templates/upper_bound_config_template @@ -5,16 +5,21 @@ ROCPRIM_DEVICE_DETAIL_CONFIG_DEVICE_UPPER_BOUND_HPP_ {%- endmacro %} {% macro kernel_configuration(measurement) -%} -upper_bound_config<{{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }}> { }; +transform_config_params{ { {{ measurement['cfg']['bs'] }}, {{ measurement['cfg']['ipt'] }} } } {%- endmacro %} -{% macro general_case() -%} -template -struct default_upper_bound_config : default_binary_search_config_base -{}; +{% macro config_picker() -%} +template constexpr auto upper_bound_config_picker() {%- endmacro %} -{% macro configuration_fallback(benchmark_of_architecture, based_on_type, fallback_selection_criteria) -%} -// Based on {{ based_on_type }} -template struct default_upper_bound_config({{ benchmark_of_architecture.name }}), value_type, output_type, {{ fallback_selection_criteria }}> : +{% macro enable_if(benchmark_of_target) -%} +{{ benchmark_of_target.get_enable_if("transform_config_params") }} +{%- endmacro %} + +{% macro default_case() -%} +return binary_search_config_params_base(); +{%- endmacro %} + +{% macro fallback_config(fallback_target) -%} + return upper_bound_config_picker, value_type, output_type>(); {%- endmacro %} From 079fad812e9450fbc22517cb7e58f6fe0cc84118 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 27 Nov 2025 11:54:29 +0000 Subject: [PATCH 13/26] Resolve "Update apply_config_improvement script for new configs" --- .../apply_config_improvements.py | 383 +++++++++++++----- 1 file changed, 289 insertions(+), 94 deletions(-) diff --git a/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py b/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py index 6636d862223..f60f86d9f5c 100644 --- a/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py +++ b/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py @@ -21,6 +21,7 @@ # THE SOFTWARE. import argparse +from dataclasses import dataclass import json import re import sys @@ -38,6 +39,112 @@ slower_rejected_count = 0 marginal_improvements_rejected_count = 0 +TARGET_GPUS_DICT = { + "MI350X": "mi350x", + "MI325X": "mi325x", + "MI308X": "mi308x", + "MI300A": "mi300a", + "MI300X": "mi300x", + "MI210": "mi210", + "MI100": "mi100", + "RX 9060": "rx9060", + "RX 9070": "rx9070", + "V620": "v620", + "RX 7900": "rx7900", + "RX 6900": "rx6900" +} + +def get_gen_from_architecture(arch): + match arch: + case "target_arch::gfx803": + return "gen::gcn3" + case "target_arch::gfx900" | "target_arch::gfx906": + return "gen::gcn5" + case "target_arch::gfx908": + return "gen::cdna1" + case "target_arch::gfx90a": + return "gen::cdna2" + case "target_arch::gfx942": + return "gen::cdna3" + case "target_arch::gfx950": + return "gen::cdna4" + case "target_arch::gfx1030": + return "gen::rdna2" + case "target_arch::gfx1100" | "target_arch::gfx1102": + return "gen::rdna3" + case "target_arch::gfx1200" | "target_arch::gfx1201": + return "gen::rdna4" + case _: + return "gen::unknown" + + +def get_target_gpu_from_context(context): + """ + Uses the benchmark run context embedded into the benchmark output json to retrieve the targeted gpu + """ + + gpu_from_context = context['hdp_name'] + + ret = "gpu::generic" + + for gpu in TARGET_GPUS_DICT.keys(): + if gpu in gpu_from_context: + ret = f'gpu::{TARGET_GPUS_DICT[gpu]}' + + if ret == "gpu::generic": + print(f"WARNING: Unrecognized GPU '{gpu_from_context}', so using gpu::generic", file=sys.stderr, flush=True) + return ret + +@dataclass(frozen=True, eq=True) +class Target: + """ + Data class describing the target of the benchmark + """ + gen: str = "gen::unknown" + arch: str = "target_arch::unknown" + gpu: str = "gpu::generic" + rep: str = "rep::amdgcn" + + def __iter__(self): + yield from (self.gen, self.arch, self.gpu, self.rep) + + def as_str(self) -> str: + return ", ".join(str(v) for v in self) + + @classmethod + def from_string(self, raw: str) -> "Target": + """Create Target instance from a raw string with key::value pairs""" + # Split by commas and strip whitespace/newlines + parts = [p.strip() for p in raw.split(",") if p.strip()] + + # Default data + data = { + "gen": self.gen, + "arch": self.arch, + "gpu": self.gpu, + "rep": self.rep, + } + + # Assign fields based on prefixes + for part in parts: + if part.startswith("gen::"): + data["gen"] = part + elif part.startswith("target_arch::"): + data["arch"] = part + elif part.startswith("gpu::"): + data["gpu"] = part + elif part.startswith("rep::"): + data["rep"] = part + + return self(**data) + + +def get_target(context): + arch = f"target_arch::{context["hdp_gcn_arch_name"].split(":")[0]}" + gpu = get_target_gpu_from_context(context) + gen = get_gen_from_architecture(arch) + rep = "rep::amdgcn" + return Target(gen, arch, gpu, rep) class colors: OK = "\033[92m" @@ -90,7 +197,8 @@ def add_new_contenders( new_config: Dict[str, Any], new_alg_data: Dict[str, Any], score_assigner: Callable[[List[Dict[str, Any]]], None], - contenders: Dict[Tuple[Any, str], Contender], + contenders: Dict[Target, Dict[Any, Contender]], + picker_strings: Dict[Tuple[Target, str], str], improvement_threshold_percentage: float, ) -> bool: global warnings, total_specializations, improvement_count, new_specialization_count, noisy_rejected_count, slower_rejected_count, marginal_improvements_rejected_count @@ -100,23 +208,25 @@ def add_new_contenders( Z_SCORE = 1.96 # 95% confidence # Collect all rows first to determine max widths - for arch, new_arch_specializations in new_config["specializations"].items(): - if arch not in new_alg_data and ( - arch != "unknown" or "gfx908" not in new_alg_data - ): + for target, new_target_specializations in new_config["specializations"].items(): + if Target() == target: + continue + if target not in new_alg_data: warnings.append( - f"{colors.WARN}The new JSON data is missing {arch} for {algorithm_name}{colors.END_COLOR}" + f"{colors.WARN}The new JSON data is missing {target} for {algorithm_name}{colors.END_COLOR}" ) continue - # create_optimization.py its create_config_file_content() chose to make "unknown" a copy of "gfx908" - new_arch_data = new_alg_data["gfx908" if arch == "unknown" else arch] - for instance_key, new_instance_data in new_arch_specializations.items(): - if instance_key not in new_arch_data: + new_target_data = new_alg_data[target] + for instance_key, new_instance_data in list(new_target_specializations.items()): + if instance_key in {"begin_of_picker", "end_of_picker"}: + picker_strings[(target, instance_key)] = new_instance_data + continue + if instance_key not in new_target_data: sys.exit( - f"{colors.FAIL}The new JSON data is missing {arch} specialization '{stringify_instance_key(instance_key)}' for {algorithm_name}{colors.END_COLOR}" + f"{colors.FAIL}The new JSON data is missing {target} specialization '{stringify_instance_key(instance_key)}' for {algorithm_name}{colors.END_COLOR}" ) - new_instances = new_arch_data[instance_key] + new_instances = new_target_data[instance_key] add_base_args(new_instance_data["base_args"], new_instances) score_assigner(new_instances) @@ -125,15 +235,13 @@ def add_new_contenders( total_specializations += 1 row = {} - row["arch"] = str(arch) + row["target"] = target.as_str() row["key"] = stringify_instance_key(instance_key) row["new_family_index"] = new_best_instance["family_index"] row["new_bps"] = f"{new_best_instance['bytes_per_second']:.2e}" - key = (instance_key, arch) - # If there is no old config specialization, we always accept the new one. - if key not in contenders: + if target not in contenders or instance_key not in contenders[target]: status = "New" colored_status = f"{colors.OK}{status}{colors.END_COLOR}" row["status"] = colored_status @@ -147,12 +255,14 @@ def add_new_contenders( rows.append(row) new_specialization_count += 1 improved = True - contenders[key] = Contender( + if target not in contenders: + contenders[target] = {} + contenders[target][instance_key] = Contender( instance=new_best_instance, string=new_instance_data["string"] ) continue - old_best_instance = contenders[key].instance + old_best_instance = contenders[target][instance_key].instance row["old_family_index"] = old_best_instance.get("family_index", "-") row["old_bps"] = ( @@ -217,7 +327,7 @@ def add_new_contenders( rows.append(row) improvement_count += 1 improved = True - contenders[key] = Contender( + contenders[target][instance_key] = Contender( instance=new_best_instance, string=new_instance_data["string"] ) @@ -225,7 +335,7 @@ def add_new_contenders( ("status", f"Status of {algorithm_name}"), ("noise", "Noise (old/new)"), ("bps", "Bytes/sec (old/new)"), - ("arch", "Arch"), + ("target", "Target"), ("key", "Specialization"), ("family_index", "Family index (old/new)"), ] @@ -281,35 +391,40 @@ def get_old_contenders( old_config: Dict[str, Any], old_alg_data: Dict[str, Any], score_assigner: Callable[[List[Dict[str, Any]]], None], -) -> Dict[Tuple[Any, str], Contender]: +): global warnings - contenders: Dict[Tuple[Any, str], Contender] = {} + contenders: Dict[Target, Dict[Any, Contender]] = {} + picker_strings: Dict[Tuple[str, str], str] = {} old_config.setdefault("specializations", {}) - for arch, old_arch_specializations in old_config["specializations"].items(): + for target, old_target_specializations in old_config["specializations"].items(): + if Target() == target: + continue + contenders[target] = {} # Always keep every old specialization, even if missing in old_alg_data - if arch not in old_alg_data and ( - arch != "unknown" or "gfx908" not in old_alg_data - ): + if target not in old_alg_data: warnings.append( - f"{colors.WARN}The old JSON data is missing {arch} for {algorithm_name}{colors.END_COLOR}" + f"{colors.WARN}The old JSON data is missing {target} for {algorithm_name}{colors.END_COLOR}" ) - old_arch_data = {} + old_target_data = {} else: - old_arch_data = old_alg_data["gfx908" if arch == "unknown" else arch] + old_target_data = old_alg_data[target] - for instance_key, old_instance_data in old_arch_specializations.items(): - if instance_key not in old_arch_data: - # If old_arch_data is falsy, then a warning was already printed - if old_arch_data: + for instance_key, old_instance_data in old_target_specializations.items(): + if instance_key in {"begin_of_picker", "end_of_picker"}: + picker_strings[(target, instance_key)] = old_instance_data + continue + if instance_key not in old_target_data: + # If old_target_data is falsy, then a warning was already printed + if old_target_data: warnings.append( - f"{colors.WARN}The old JSON data is missing {arch} specialization '{stringify_instance_key(instance_key)}' for {algorithm_name}{colors.END_COLOR}" + f"{colors.WARN}The old JSON data is missing {target} specialization '{stringify_instance_key(instance_key)}' for {algorithm_name}{colors.END_COLOR}" ) old_instances = [] else: - old_instances = old_arch_data[instance_key] + old_instances = old_target_data[instance_key] add_base_args(old_instance_data["base_args"], old_instances) score_assigner(old_instances) @@ -319,17 +434,38 @@ def get_old_contenders( get_best_instance(old_instances) if old_instances else {} ) - key = (instance_key, arch) - contenders[key] = Contender( + contenders[target][instance_key] = Contender( instance=old_best_instance, string=old_instance_data["string"] ) - return contenders + return (contenders, picker_strings) + +def get_comp_targets(old_config: Dict[str, Any], new_config: Dict[str, Any], algorithm_name: str): + ret = f"// All existing configs\nusing {algorithm_name}_targets = comp_targets<" + + unique_targets = [] + for target in new_config["specializations"].keys(): + if target != Target() and target not in unique_targets: + unique_targets.append(target) + for target in old_config["specializations"].keys(): + if target != Target() and target not in unique_targets: + unique_targets.append(target) + + for unique_target in unique_targets: + ret += f"comp_target<{unique_target.as_str()}>," + + ret += f"comp_target<{Target().as_str()}>>;" + return ret def get_best_instance(instances): return max(instances, key=lambda x: x["score"]) +def parse_args(base_args): + numbers = [] + for item in base_args: + numbers.extend(map(int, re.findall(r'\d+', item))) + return numbers def add_base_args(base_args, instances): """ @@ -337,8 +473,9 @@ def add_base_args(base_args, instances): """ for instance in instances: if instance["algo"] in {"merge_sort_block_sort", "radix_sort_block_sort"}: - instance["bs"] = int(base_args[0]) - instance["ipt"] = int(base_args[1]) + args = parse_args(base_args) + instance["bs"] = int(args[0]) + instance["ipt"] = int(args[1]) def score_assigner_default(instances: List[Dict[str, Any]]) -> None: @@ -399,7 +536,7 @@ def generate_improved_configs( new_alg_data = new_data.get(algorithm_name, {}) score_assigner = get_score_assigner(algorithm_name) - contenders = get_old_contenders( + contenders, picker_strings = get_old_contenders( algorithm_name, old_config or {}, old_alg_data or {}, score_assigner ) improved = add_new_contenders( @@ -408,6 +545,7 @@ def generate_improved_configs( new_alg_data, score_assigner, contenders, + picker_strings, improvement_threshold_percentage, ) if not improved: @@ -420,14 +558,32 @@ def generate_improved_configs( if old_config else new_config["start_of_config"] ) - for contender in contenders.values(): - f.write(contender.string) + # Add every picker for every different target + for target, target_contenders in contenders.items(): + f.write(picker_strings[(target, "begin_of_picker")]) + for contender in target_contenders.values(): + f.write(contender.string) + f.write(picker_strings[(target, "end_of_picker")]) + + # Add unknown target fallback case + unknown_target_contender = ( + old_config["specializations"][Target()]["begin_of_picker"] + if old_config + else new_config["specializations"][Target()]["begin_of_picker"] + ).lstrip("\n") + + f.write(unknown_target_contender) + + # Add comp_targets + comp_targets = get_comp_targets(old_config, new_config, algorithm_name) + f.write(comp_targets) + # Remove leading newlines from end_of_config to avoid triple newlines end = ( old_config["end_of_config"] if old_config else new_config["end_of_config"] - ).lstrip("\n") + ) if not end.endswith("\n"): end += "\n" f.write(end) @@ -437,24 +593,30 @@ def extract_template_arguments(code: str) -> list[str]: """ Extracts the top-level template arguments from a base config class in a C++ struct specialization. - Assumes the input contains a line ending in `: XXX_config<...> {`, where `XXX_config` is - the name of the base config class and the angle brackets contain template arguments. + Assumes the input contains a line ending in `: XXX_config_params{...} {`, where `XXX_config_params` is + the name of the base config params class and the curly brackets contain template arguments. Args: code (str): A string containing the C++ struct definition. Returns: List[str]: A list of top-level template arguments as strings. - Nested templates like `kernel_config<1024, 1>` are preserved as single items. + Nested structs like `kernel_config_params{1024, 1}` are preserved as single items. Example: Input: - 'struct foo : some_config<256, kernel_config<128, 2>, 1 << 2> {' + ``` + if constexpr(true) + { + return some_config_params{256, kernel_config_params{128, 2}, 1 << 2}; + } + ``` Output: - ['256', 'kernel_config<128, 2>', '1 << 2'] + ['256', 'kernel_config_params{128, 2}', '1 << 2'] """ # Match the base class ending in `_config<...> {` - match = re.search(r":\s*\w+_config<(.+?)>\s*{", code, re.DOTALL) + match = re.search(r"return\s+\w+_config_params\s*\{\s*([\s\S]*?)\s*\};", code, re.DOTALL) + if not match: return [] @@ -466,15 +628,15 @@ def extract_template_arguments(code: str) -> list[str]: depth = 0 i = 0 while i < len(template_args): - if template_args[i : i + 2] in ("<<", ">>"): + if template_args[i : i + 2] in ("{{", "}}"): current += template_args[i : i + 2] i += 2 continue char = template_args[i] - if char == "<": + if char == "{": depth += 1 - elif char == ">": + elif char == "}": depth -= 1 if char == "," and depth == 0: @@ -501,21 +663,24 @@ def get_specialization_key(specialization_string: str) -> Any: # 'select_flag': { # 'full_name': 'device_select_flag.hpp', # 'start_of_config': '// Copyright ...', +# 'comp_targets': 'using select_flag_targets = comp_targets< ... >;' # 'end_of_config': '} // end namespace detail ...', # 'specializations': { -# 'gfx1200': { +# Target(gen::rdna2, target_arch::gfx1030, gpu::v620, rep::amdgcn): { +# 'begin_of_picker': 'template constexpr auto select_flag_config_picker() ... {\n', +# 'end_of_picker': `// Default case if none of the conditions match\nreturn partition_config_params_base();}\n`, # Instance(key_type='double'): { -# 'string': '// Based on key_type = double\n ... select_config<512, kernel_config<128, 2>>\n{};\n\n', +# 'string': '// Based on key_type = double\n ... partition_config_params{512, kernel_config_params{128, 2}};\n', # 'base_args': [ # '512', -# 'kernel_config<128, 2>' +# 'kernel_config_params{128, 2}' # ] # }, # Instance(key_type='float'): { -# 'string': '// Based on key_type = float\n ... select_config<1024, kernel_config<128, 2>>\n{};\n\n', +# 'string': '// Based on key_type = float\n ... partition_config_params{1024, kernel_config_params{128, 2}};\n', # 'base_args': [ # '1024', -# 'kernel_config<128, 2>' +# 'kernel_config_params{128, 2}' # ] # } # } @@ -531,32 +696,55 @@ def read_configs(dir_path: Path) -> Dict[str, Any]: config["full_name"] = hpp_path.name text = hpp_path.read_text() - matches = re.split(r"(// Based on .*?{\s*};\n\n)", text, flags=re.DOTALL) - matches = list(filter(None, matches)) - config["start_of_config"] = matches[0] - config["end_of_config"] = matches[-1] + pickers = re.split(r"(template\s*<[\s\S]*?>\s*\{[\s\S]*?}\n\n)", text, flags=re.DOTALL) + pickers = list(filter(None, pickers)) + # Remove spaces from picker functions to have a more consistent input. + pickers = [picker for picker in pickers if picker.strip() != ""] + + # All the code before the picker functions. + config["start_of_config"] = pickers[0] + + # The code after the picker functions need to be divided in comp_targets and the end. + comp_target = re.split(r"(// .*\nusing[\s\S]*?;)", pickers[-1], flags=re.DOTALL) + comp_target = list(filter(None, comp_target)) + config["comp_targets"] = comp_target[0] + config["end_of_config"] = comp_target[-1] all_specializations = config.setdefault("specializations", {}) - specialization_strings = matches[1:-1] - for specialization_string in specialization_strings: - arch_match = re.search(r"target_arch::(.*?)\)", specialization_string) - if not arch_match: - sys.exit( - f"{colors.FAIL}Could not find arch in specialization: {specialization_string}{colors.END_COLOR}" - ) - arch = arch_match.group(1) + picker_strings = pickers[1:-1] + for picker_string in picker_strings: + target_match = re.search(r"comp_target<((.||\s)*?)>", picker_string) - specialization_key = get_specialization_key(specialization_string) - arch_specializations = all_specializations.setdefault(arch, {}) - if specialization_key in arch_specializations: + if not target_match: sys.exit( - f"{colors.FAIL}Specialization key duplicate '{specialization_key}'{colors.END_COLOR}" + f"{colors.FAIL}Could not find arch in specialization: {picker_string}{colors.END_COLOR}" ) - arch_specializations[specialization_key] = { - "string": specialization_string, - "base_args": extract_template_arguments(specialization_string), - } + + target = Target.from_string(target_match.group(1)) + + specializations = re.split(r"(\s*// Based on[\s\S]*?\}\n(?=\s*//))", picker_string, flags=re.DOTALL) + specializations = list(filter(None, specializations)) + target_specializations = all_specializations.setdefault(target, {}) + target_specializations["begin_of_picker"] = specializations[0] + target_specializations["end_of_picker"] = specializations[-1] + + for specialization_string in specializations[1:-1]: + specialization_key = get_specialization_key(specialization_string) + + if specialization_key in target_specializations: + sys.exit( + f"{colors.FAIL}Specialization key duplicate '{specialization_key}'{colors.END_COLOR}" + ) + args = extract_template_arguments(specialization_string) + if args is []: + sys.exit( + f"{colors.FAIL}Arguments are empty '{specialization_key}'{colors.END_COLOR}" + ) + target_specializations[specialization_key] = { + "string": specialization_string, + "base_args": args, + } return configs @@ -575,7 +763,7 @@ def get_instance_key(instanced_types: Dict[str, Any], selectors: List[str]) -> A # This is the format of the returned dictionary: # { # 'select_flag': { -# 'gfx1200': { +# Target(gen::rdna2, target_arch::gfx1030, gpu::v620, rep::amdgcn): { # Instance(key_type='double'): [ # { 'items_per_second': 200, 'segment_count': 10 }, # { 'items_per_second': 300, 'segment_count': 20 } @@ -592,7 +780,7 @@ def read_data(dir_path: Path, selectors: Dict[str, List[str]]) -> Dict[str, Any] for json_path in dir_path.rglob("*.json"): json_data = json.loads(json_path.read_text()) - arch = json_data["context"]["hdp_gcn_arch_name"].split(":")[0] + target = get_target(json_data["context"]) for benchmark in json_data["benchmarks"]: name = benchmark["name"] @@ -614,14 +802,14 @@ def read_data(dir_path: Path, selectors: Dict[str, List[str]]) -> Dict[str, Any] algorithm_name += "_" + data["subalgo"] alg_data = all_data.setdefault(algorithm_name, {}) - arch_data = alg_data.setdefault(arch, defaultdict(list)) + target_data = alg_data.setdefault(target, defaultdict(list)) if algorithm_name not in selectors: sys.exit( f"{colors.FAIL}No selectors found for algorithm '{algorithm_name}' in the selectors JSON{colors.END_COLOR}" ) instance_key = get_instance_key(data, selectors[algorithm_name]) - arch_data[instance_key].append(data) + target_data[instance_key].append(data) return all_data @@ -706,21 +894,23 @@ def add_arguments(parser: StrictArgumentParser) -> None: class TestExtractTemplateArguments(unittest.TestCase): def test_complex(self): specialization = """ - struct default_radix_sort_onesweep_config< - static_cast(target_arch::gfx1030), - key_type, - value_type, - std::enable_if_t<(bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) - && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) - && (sizeof(value_type) > 8))>> - : radix_sort_onesweep_config, (1 << 17) + 1 >> 2 + 70000, - 8, - block_radix_rank_algorithm::match> - {}; + // Based on key_type = double, value_type = rocprim::int128_t + if constexpr((bool(rocprim::is_floating_point::value) && (sizeof(key_type) <= 8) + && (sizeof(key_type) > 4) && (sizeof(value_type) <= 16) + && (sizeof(value_type) > 8))) + { + return radix_sort_onesweep_config_params{ + kernel_config_params{1024, 1}, + (1 << 17) + 1 >> 2 + 70000, + 8, + block_radix_rank_algorithm::match + }; + } + // Needs a comment after it. """ expected = [ - "kernel_config<1024, 1>", + "kernel_config_params{1024, 1}", "(1 << 17) + 1 >> 2 + 70000", "8", "block_radix_rank_algorithm::match", @@ -753,6 +943,11 @@ def main() -> None: args.improved_configs_dir, ) + if total_specializations == 0: + warnings.append( + f"{colors.WARN}No specializations are found in the config files!{colors.END_COLOR}" + ) + if warnings: print("\n".join(warnings)) print("") From fa3b066965a97848faa425fd23629eb65698ef2a Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 27 Nov 2025 14:25:59 +0000 Subject: [PATCH 14/26] Added to CHANGELOG --- projects/rocprim/CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/projects/rocprim/CHANGELOG.md b/projects/rocprim/CHANGELOG.md index 8232e505acd..940eaac0700 100644 --- a/projects/rocprim/CHANGELOG.md +++ b/projects/rocprim/CHANGELOG.md @@ -2,6 +2,12 @@ Full documentation for rocPRIM is available at [https://rocm.docs.amd.com/projects/rocPRIM/en/latest/](https://rocm.docs.amd.com/projects/rocPRIM/en/latest/). +## rocPRIM x.y.z for ROCm 8.0 + +### Optimizations + +* Updated config system to pick better fallback configs for untuned GPUs. + ## rocPRIM 4.2.0 for ROCm 7.2 ### Added From 0b5a48ee758e0d97b2c37aec1f55c24953f2b5c9 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 28 Nov 2025 08:08:55 +0000 Subject: [PATCH 15/26] Cleanup target_config --- .../include/rocprim/device/config_types.hpp | 58 ++++++------------- .../test/rocprim/test_config_dispatch.cpp | 32 +++++----- 2 files changed, 36 insertions(+), 54 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 19ff896c9a1..3923649b1ab 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -382,28 +382,20 @@ constexpr void for_each_arch(F&& f) std::make_index_sequence{}); } -constexpr arch::wavefront::target arch_wavefront_size(const target_arch target_arch) +constexpr arch::wavefront::target gen_wavefront_size(const gen gen) { - switch(target_arch) + switch(gen) { - case target_arch::unknown: return arch::wavefront::get_target(); - case target_arch::gfx803: return arch::wavefront::target::size64; - case target_arch::gfx900: return arch::wavefront::target::size64; - case target_arch::gfx906: return arch::wavefront::target::size64; - case target_arch::gfx908: return arch::wavefront::target::size64; - case target_arch::gfx90a: return arch::wavefront::target::size64; - case target_arch::gfx942: return arch::wavefront::target::size64; - case target_arch::gfx950: return arch::wavefront::target::size64; - case target_arch::gfx1030: return arch::wavefront::target::size32; - case target_arch::gfx1100: return arch::wavefront::target::size32; - case target_arch::gfx1102: return arch::wavefront::target::size32; - case target_arch::gfx1152: return arch::wavefront::target::size32; - case target_arch::gfx1153: return arch::wavefront::target::size32; - case target_arch::gfx1200: return arch::wavefront::target::size32; - case target_arch::gfx1201: return arch::wavefront::target::size32; - - // Unreachable - case target_arch::invalid: return arch::wavefront::target::dynamic; + case gen::unknown: return arch::wavefront::get_target(); + case gen::gcn3: + case gen::gcn5: + case gen::cdna1: + case gen::cdna2: + case gen::cdna3: + case gen::cdna4: return arch::wavefront::target::size64; + case gen::rdna2: + case gen::rdna3: + case gen::rdna4: return arch::wavefront::target::size32; } } @@ -516,9 +508,9 @@ constexpr typename Selector::param_type get_config(Config config, target t) template struct target_config { - constexpr static auto params = get_config(Config{}, target{Target{}}); - constexpr static auto wavefront = arch_wavefront_size(Target::i); - constexpr static auto arch = Target::i; + constexpr static target config_target = target{Target{}}; + constexpr static auto params = get_config(Config{}, config_target); + constexpr static auto wavefront = gen_wavefront_size(Target::g); }; template @@ -536,8 +528,7 @@ template - class LaunchSelector, - bool PassTarget = false> + class LaunchSelector> ROCPRIM_KERNEL __launch_bounds__((LaunchSelector::block_size)) void trampoline_kernel(Kernel kernel) { @@ -550,14 +541,7 @@ void trampoline_kernel(Kernel kernel) if constexpr(Target::i == device_arch_target.i) #endif { - if constexpr(PassTarget) - { - kernel(ArchConfig{}, Target{}); - } - else - { - kernel(ArchConfig{}); - } + kernel(ArchConfig{}); } #if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 else @@ -570,7 +554,6 @@ void trampoline_kernel(Kernel kernel) template class LaunchSelector = default_config_static_selector, - bool PassTarget = false, class Kernel> auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan { @@ -590,8 +573,7 @@ auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan; + LaunchSelector>; } }); @@ -601,13 +583,11 @@ auto make_launch_plan(target target_current, Kernel kernel) -> launch_plan class LaunchSelector = default_config_static_selector, - bool PassTarget = false, class Kernel> hipError_t execute_launch_plan( target t, Kernel kernel, dim3 grid_size, dim3 block_size, size_t shmem, hipStream_t stream) { - const auto launch_plan - = make_launch_plan(t, kernel); + const auto launch_plan = make_launch_plan(t, kernel); launch_plan.launch(grid_size, block_size, shmem, stream); return hipGetLastError(); } diff --git a/projects/rocprim/test/rocprim/test_config_dispatch.cpp b/projects/rocprim/test/rocprim/test_config_dispatch.cpp index 547c85e993a..9a10e590800 100644 --- a/projects/rocprim/test/rocprim/test_config_dispatch.cpp +++ b/projects/rocprim/test/rocprim/test_config_dispatch.cpp @@ -414,17 +414,16 @@ TEST(RocprimConfigDispatchTests, ExecuteLaunchPlan) target* d_output; HIP_CHECK(hipMalloc(&d_output, sizeof(target))); - auto kernel = [=](auto arch_config, auto t) - { - (void)arch_config; - *d_output = target{t}; - }; + auto kernel = [=](auto arch_config) { *d_output = decltype(arch_config)::config_target; }; - HIP_CHECK( - (execute_launch_plan, - default_config_static_selector, - true>(current_target, kernel, dim3(1), dim3(block_size), 0, stream))); + HIP_CHECK((execute_launch_plan, + default_config_static_selector>(current_target, + kernel, + dim3(1), + dim3(block_size), + 0, + stream))); // Impossible target target h_output{gen::cdna4, target_arch::gfx942, gpu::rx6900, rep::spirv}; @@ -432,11 +431,14 @@ TEST(RocprimConfigDispatchTests, ExecuteLaunchPlan) // Compared to targets with only unknown inside. ASSERT_EQ(target(), h_output); - HIP_CHECK( - (execute_launch_plan, - default_config_static_selector, - true>(current_target, kernel, dim3(1), dim3(block_size), 0, stream))); + HIP_CHECK((execute_launch_plan, + default_config_static_selector>(current_target, + kernel, + dim3(1), + dim3(block_size), + 0, + stream))); // Impossible target h_output = target{gen::cdna4, target_arch::gfx942, gpu::rx6900, rep::spirv}; From 293777c47dd29c7699bf2342e4b4861cab69ae34 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Tue, 2 Dec 2025 13:53:27 +0000 Subject: [PATCH 16/26] Fix base block methods adjacent_difference_config --- .../include/rocprim/device/detail/device_config_helper.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp index 0059f80eabf..2da4f8f8585 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_config_helper.hpp @@ -828,8 +828,8 @@ constexpr histogram_config_params histogram_config_params_base() struct adjacent_difference_config_params { kernel_config_params kernel_config{}; - ::rocprim::block_load_method block_load_method{}; - ::rocprim::block_store_method block_store_method{}; + ::rocprim::block_load_method block_load_method = block_load_method::block_load_transpose; + ::rocprim::block_store_method block_store_method = block_store_method::block_store_transpose; }; } // namespace detail From 6448f1303e808c0bf462e677a2000cfbc71ebbe2 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 4 Dec 2025 08:23:52 +0000 Subject: [PATCH 17/26] Clear previous caches before current one is created --- .../benchmark_device_histogram.parallel.hpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp b/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp index e3ff908c714..b66821fb76c 100644 --- a/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp +++ b/projects/rocprim/benchmark/benchmark_device_histogram.parallel.hpp @@ -235,6 +235,16 @@ struct device_histogram_benchmark : public benchmark_utils::autotune_interface }; }; + // Clear caches for other types that are either empty or already done. + clear_other_caches(); + const std::size_t size = bytes / Channels; size_t temporary_storage_bytes = 0; @@ -318,16 +328,6 @@ struct device_histogram_benchmark : public benchmark_utils::autotune_interface { HIP_CHECK(hipFree(d_histogram[channel])); } - - // Clear caches for other types that are either empty or already done. - clear_other_caches(); } }; From a6a1cf654d16ac24d836e5622439620f2412bd3f Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 5 Dec 2025 07:49:16 +0000 Subject: [PATCH 18/26] Give device_histogram the same fallback as previous configs system and fix predicate_flag config choosing error. --- .../device/detail/config/device_histogram.hpp | 11 +++++++++++ .../include/rocprim/device/device_partition.hpp | 12 ++++++------ .../rocprim/device/device_partition_config.hpp | 4 ++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp index e5f67d94e9a..052ff763ec0 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_histogram.hpp @@ -940,6 +940,16 @@ constexpr auto histogram_config_picker() -> std::enable_if_t< return histogram_config_params_base(); } +template +constexpr auto histogram_config_picker() -> std::enable_if_t< + std::is_same>::value, + histogram_config_params> +{ + // Same fallback as previous config system. + return histogram_config_params_base(); +} + template constexpr auto histogram_config_picker() -> std::enable_if_t< std::is_same std::enable_if_t< using histogram_targets = comp_targets, comp_target, + comp_target, comp_target, comp_target, comp_target, diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp index d3b384fbc49..55e58c1a0a3 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_partition.hpp @@ -138,8 +138,6 @@ inline hipError_t partition_impl(void* temporary_storage, using scan_state_type = detail::lookback_scan_state; using block_id_type = detail::block_id_wrapper; - using selector = partition_config_selector; - constexpr bool write_only_selected = SubAlgo == partition_subalgo::select_flag || SubAlgo == partition_subalgo::select_predicate @@ -160,6 +158,12 @@ inline hipError_t partition_impl(void* temporary_storage, ? select_method::predicated_flag : (is_flag ? select_method::flag : select_method::predicate)); + using flag_type = + typename std::conditional::value_type, + bool>::type; + using selector = partition_config_selector; + detail::target_arch target_arch; ROCPRIM_RETURN_ON_ERROR(host_target_arch(stream, target_arch)); detail::gpu target_gpu; @@ -198,10 +202,6 @@ inline hipError_t partition_impl(void* temporary_storage, // vsmem size void* vsmem = nullptr; size_t virtual_shared_memory_size = 0; - using flag_type = - typename std::conditional::value_type, - bool>::type; virtual_shared_memory_size = get_partition_vsmem_size_per_block +template struct partition_config_selector { using targets = typename decltype(algo_target_type())::type; @@ -128,7 +128,7 @@ struct partition_config_selector } else if constexpr(SubAlgo == partition_subalgo::select_predicated_flag) { - return select_predicated_flag_config_picker(); + return select_predicated_flag_config_picker(); } else if constexpr(SubAlgo == partition_subalgo::select_unique) { From 8afe9ae7501b9917a4041c352841083c9626c264 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 11 Dec 2025 13:53:31 +0000 Subject: [PATCH 19/26] Resolve "Fix generic compile target new config system" --- .../rocprim/include/rocprim/config.hpp | 125 ++++++++++-------- .../include/rocprim/device/config_types.hpp | 9 +- .../device/detail/device_radix_sort.hpp | 4 +- .../include/rocprim/thread/thread_load.hpp | 4 +- .../include/rocprim/thread/thread_store.hpp | 4 +- .../rocprim/warp/detail/warp_reduce_dpp.hpp | 19 ++- .../rocprim/warp/detail/warp_scan_dpp.hpp | 19 ++- .../test/rocprim/test_config_dispatch.cpp | 20 ++- 8 files changed, 122 insertions(+), 82 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/config.hpp b/projects/rocprim/rocprim/include/rocprim/config.hpp index 50a65505db7..2b9c5d40ddc 100644 --- a/projects/rocprim/rocprim/include/rocprim/config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/config.hpp @@ -152,88 +152,99 @@ #if !defined(ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS) #define ROCPRIM_THREAD_STORE_USE_CACHE_MODIFIERS 1 #endif - #define IS_CDNA3() \ - __builtin_amdgcn_processor_is("gfx942") || __builtin_amdgcn_processor_is("gfx950") \ - || __builtin_amdgcn_processor_is("gfx9-4-generic") - #define IS_CDNA2() __builtin_amdgcn_processor_is("gfx90a") - #define IS_CDNA1() __builtin_amdgcn_processor_is("gfx908") - #define IS_GCN5() \ - __builtin_amdgcn_processor_is("gfx900") || __builtin_amdgcn_processor_is("gfx902") \ - || __builtin_amdgcn_processor_is("gfx904") || __builtin_amdgcn_processor_is("gfx906") \ - || __builtin_amdgcn_processor_is("gfx90c") \ - || __builtin_amdgcn_processor_is("gfx9-generic") - #define IS_RDNA4() \ - __builtin_amdgcn_processor_is("gfx1200") || __builtin_amdgcn_processor_is("gfx1201") \ - || __builtin_amdgcn_processor_is("gfx12-generic") // TODO: Re-enable gfx1250 when supported by compiler - #define IS_RDNA3() \ - __builtin_amdgcn_processor_is("gfx1100") || __builtin_amdgcn_processor_is("gfx1101") \ - || __builtin_amdgcn_processor_is("gfx1102") \ - || __builtin_amdgcn_processor_is("gfx1103") \ - || __builtin_amdgcn_processor_is("gfx1152") \ - || __builtin_amdgcn_processor_is("gfx1153") \ - || __builtin_amdgcn_processor_is("gfx11-generic") - #define IS_RDNA2() \ - __builtin_amdgcn_processor_is("gfx1030") || __builtin_amdgcn_processor_is("gfx1031") \ - || __builtin_amdgcn_processor_is("gfx1032") \ - || __builtin_amdgcn_processor_is("gfx1033") \ - || __builtin_amdgcn_processor_is("gfx1034") \ - || __builtin_amdgcn_processor_is("gfx1035") \ - || __builtin_amdgcn_processor_is("gfx1036") \ - || __builtin_amdgcn_processor_is("gfx10-3-generic") - #define IS_RDNA1() \ - __builtin_amdgcn_processor_is("gfx1010") || __builtin_amdgcn_processor_is("gfx1011") \ - || __builtin_amdgcn_processor_is("gfx1012") \ - || __builtin_amdgcn_processor_is("gfx1013") \ - || __builtin_amdgcn_processor_is("gfx10-1-generic") - #define IS_GCN3() \ - __builtin_amdgcn_processor_is("gfx801") || __builtin_amdgcn_processor_is("gfx802") \ - || __builtin_amdgcn_processor_is("gfx803") || __builtin_amdgcn_processor_is("gfx805") \ - || __builtin_amdgcn_processor_is("gfx810") + #define ROCPRIM_IS_CDNA3() \ + (__builtin_amdgcn_processor_is("gfx942") || __builtin_amdgcn_processor_is("gfx950") \ + || __builtin_amdgcn_processor_is("gfx9-4-generic")) + #define ROCPRIM_IS_CDNA2() (__builtin_amdgcn_processor_is("gfx90a")) + #define ROCPRIM_IS_CDNA1() (__builtin_amdgcn_processor_is("gfx908")) + #define ROCPRIM_IS_GCN5() \ + (__builtin_amdgcn_processor_is("gfx900") || __builtin_amdgcn_processor_is("gfx902") \ + || __builtin_amdgcn_processor_is("gfx904") || __builtin_amdgcn_processor_is("gfx906") \ + || __builtin_amdgcn_processor_is("gfx90c") \ + || __builtin_amdgcn_processor_is("gfx9-generic")) + #define ROCPRIM_IS_RDNA4() \ + (__builtin_amdgcn_processor_is("gfx1200") || __builtin_amdgcn_processor_is("gfx1201") \ + || __builtin_amdgcn_processor_is( \ + "gfx12-generic")) // TODO: Re-enable gfx1250 when supported by compiler + #define ROCPRIM_IS_RDNA3() \ + (__builtin_amdgcn_processor_is("gfx1100") || __builtin_amdgcn_processor_is("gfx1101") \ + || __builtin_amdgcn_processor_is("gfx1102") || __builtin_amdgcn_processor_is("gfx1103") \ + || __builtin_amdgcn_processor_is("gfx1152") || __builtin_amdgcn_processor_is("gfx1153") \ + || __builtin_amdgcn_processor_is("gfx11-generic")) + #define ROCPRIM_IS_RDNA2() \ + (__builtin_amdgcn_processor_is("gfx1030") || __builtin_amdgcn_processor_is("gfx1031") \ + || __builtin_amdgcn_processor_is("gfx1032") || __builtin_amdgcn_processor_is("gfx1033") \ + || __builtin_amdgcn_processor_is("gfx1034") || __builtin_amdgcn_processor_is("gfx1035") \ + || __builtin_amdgcn_processor_is("gfx1036") \ + || __builtin_amdgcn_processor_is("gfx10-3-generic")) + #define ROCPRIM_IS_RDNA1() \ + (__builtin_amdgcn_processor_is("gfx1010") || __builtin_amdgcn_processor_is("gfx1011") \ + || __builtin_amdgcn_processor_is("gfx1012") || __builtin_amdgcn_processor_is("gfx1013") \ + || __builtin_amdgcn_processor_is("gfx10-1-generic")) + #define ROCPRIM_IS_GCN3() \ + (__builtin_amdgcn_processor_is("gfx801") || __builtin_amdgcn_processor_is("gfx802") \ + || __builtin_amdgcn_processor_is("gfx803") || __builtin_amdgcn_processor_is("gfx805") \ + || __builtin_amdgcn_processor_is("gfx810")) + #define ROCPRIM_IS_GENERIC() \ + (__builtin_amdgcn_processor_is("gfx9-4-generic") \ + || __builtin_amdgcn_processor_is("gfx9-generic") \ + || __builtin_amdgcn_processor_is("gfx11-generic") \ + || __builtin_amdgcn_processor_is("gfx10-3-generic") \ + || __builtin_amdgcn_processor_is("gfx10-1-generic") \ + || __builtin_amdgcn_processor_is("gfx12-generic")) #else #if defined(ROCPRIM_TARGET_CDNA3) - #define IS_CDNA3() 1 + #define ROCPRIM_IS_CDNA3() 1 #else - #define IS_CDNA3() 0 + #define ROCPRIM_IS_CDNA3() 0 #endif #if defined(ROCPRIM_TARGET_CDNA2) - #define IS_CDNA2() 1 + #define ROCPRIM_IS_CDNA2() 1 #else - #define IS_CDNA2() 0 + #define ROCPRIM_IS_CDNA2() 0 #endif #if defined(ROCPRIM_TARGET_CDNA1) - #define IS_CDNA1() 1 + #define ROCPRIM_IS_CDNA1() 1 #else - #define IS_CDNA1() 0 + #define ROCPRIM_IS_CDNA1() 0 #endif #if defined(ROCPRIM_TARGET_GCN5) - #define IS_GCN5() 1 + #define ROCPRIM_IS_GCN5() 1 #else - #define IS_GCN5() 0 + #define ROCPRIM_IS_GCN5() 0 #endif #if defined(ROCPRIM_TARGET_RDNA4) - #define IS_RDNA4() 1 + #define ROCPRIM_IS_RDNA4() 1 #else - #define IS_RDNA4() 0 + #define ROCPRIM_IS_RDNA4() 0 #endif #if defined(ROCPRIM_TARGET_RDNA3) - #define IS_RDNA3() 1 + #define ROCPRIM_IS_RDNA3() 1 #else - #define IS_RDNA3() 0 + #define ROCPRIM_IS_RDNA3() 0 #endif #if defined(ROCPRIM_TARGET_RDNA2) - #define IS_RDNA2() 1 + #define ROCPRIM_IS_RDNA2() 1 #else - #define IS_RDNA2() 0 + #define ROCPRIM_IS_RDNA2() 0 #endif #if defined(ROCPRIM_TARGET_RDNA1) - #define IS_RDNA1() 1 + #define ROCPRIM_IS_RDNA1() 1 #else - #define IS_RDNA1() 0 + #define ROCPRIM_IS_RDNA1() 0 #endif #if defined(ROCPRIM_TARGET_GCN3) - #define IS_GCN3() 1 + #define ROCPRIM_IS_GCN3() 1 #else - #define IS_GCN3() 0 + #define ROCPRIM_IS_GCN3() 0 + #endif + + #if defined(__gfx9_generic__) || defined(__gfx9_4_generic__) || defined(__gfx10_1_generic__) \ + || defined(__gfx10_3_generic__) || defined(__gfx11_generic__) \ + || defined(__gfx12_generic__) + #define ROCPRIM_IS_GENERIC() 1 + #else + #define ROCPRIM_IS_GENERIC() 0 #endif #if !defined(ROCPRIM_THREAD_LOAD_USE_CACHE_MODIFIERS) @@ -267,7 +278,7 @@ #define ROCPRIM_DETAIL_HAS_DPP 1 #endif -#if (!defined(ROCPRIM_DISABLE_DPP) || ROCPRIM_DISABLE_DPP == 0) \ +#if(!defined(ROCPRIM_DISABLE_DPP) || ROCPRIM_DISABLE_DPP == 0) \ && (defined(ROCPRIM_DETAIL_HAS_DPP) && ROCPRIM_DETAIL_HAS_DPP == 1) #define ROCPRIM_DETAIL_USE_DPP 1 #else @@ -292,7 +303,7 @@ /// Quad size (group of 4 threads) #define ROCPRIM_QUAD_SIZE 4u -#if (defined(_MSC_VER) && !defined(__clang__)) || (defined(__GNUC__) && !defined(__clang__)) +#if(defined(_MSC_VER) && !defined(__clang__)) || (defined(__GNUC__) && !defined(__clang__)) #define ROCPRIM_UNROLL #define ROCPRIM_NO_UNROLL #else diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 3923649b1ab..ed3b78e359b 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -538,16 +538,21 @@ void trampoline_kernel(Kernel kernel) using Targets = typename Selector::targets; // If the arch does not exist in the Targets it should run the arch for the most_common_config. constexpr target device_arch_target = most_common_config(target(device_target_arch())); + // If the build time arch from device_target_arch is a generic arch it is not the same as the runtime arch. if constexpr(Target::i == device_arch_target.i) -#endif { kernel(ArchConfig{}); } -#if !defined(ROCPRIM_TARGET_SPIRV) || ROCPRIM_TARGET_SPIRV == 0 + else if constexpr(ROCPRIM_IS_GENERIC()) + { + kernel(ArchConfig{}); + } else { __builtin_unreachable(); } +#else + kernel(ArchConfig{}); #endif } diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/device_radix_sort.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/device_radix_sort.hpp index 36c570b92fe..0bf41a523fa 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/device_radix_sort.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/device_radix_sort.hpp @@ -1276,7 +1276,7 @@ struct onesweep_iteration_helper for(unsigned int i = 0; i < ItemsPerThread; ++i) { // It only seems worse on gfx942 in some cases. - if ROCPRIM_AMDGCN_CONSTEXPR(IS_CDNA3()) + if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_CDNA3()) { const int offset = ranks[i] - x; if(offset >= 0 && offset < static_cast(BlockSize * NKey)) @@ -1369,7 +1369,7 @@ struct onesweep_iteration_helper for(unsigned int i = 0; i < ItemsPerThread; ++i) { // It only seems worse on gfx942 in some cases. - if ROCPRIM_AMDGCN_CONSTEXPR(IS_CDNA3()) + if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_CDNA3()) { const int offset = ranks[i] - x; if(offset >= 0 && offset < static_cast(BlockSize * NValue)) diff --git a/projects/rocprim/rocprim/include/rocprim/thread/thread_load.hpp b/projects/rocprim/rocprim/include/rocprim/thread/thread_load.hpp index f762d6b1eef..b410b91ec75 100644 --- a/projects/rocprim/rocprim/include/rocprim/thread/thread_load.hpp +++ b/projects/rocprim/rocprim/include/rocprim/thread/thread_load.hpp @@ -85,14 +85,14 @@ T asm_thread_load(void* ptr) ROCPRIM_DEVICE ROCPRIM_INLINE type asm_thread_load(void* ptr) \ { \ interim_type retval; \ - if ROCPRIM_AMDGCN_CONSTEXPR(IS_RDNA4()) \ + if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_RDNA4()) \ { \ asm volatile(#asm_operator " %0, %1 th:TH_DEFAULT scope:SCOPE_DEV\n\t" \ "s_wait_loadcnt_dscnt(%2)" \ : "=&v"(retval) \ : "v"(ptr), "I"(0x00)); \ } \ - else if ROCPRIM_AMDGCN_CONSTEXPR(IS_CDNA3()) \ + else if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_CDNA3()) \ { \ asm volatile(#asm_operator " %0, %1 sc0 nt\n\t" \ "s_waitcnt(%2)" \ diff --git a/projects/rocprim/rocprim/include/rocprim/thread/thread_store.hpp b/projects/rocprim/rocprim/include/rocprim/thread/thread_store.hpp index bdc1583f886..a12d6e50ed0 100644 --- a/projects/rocprim/rocprim/include/rocprim/thread/thread_store.hpp +++ b/projects/rocprim/rocprim/include/rocprim/thread/thread_store.hpp @@ -83,14 +83,14 @@ void asm_thread_store(void* ptr, T val) type val) \ { \ interim_type temp_val = *bit_cast(&val); \ - if ROCPRIM_AMDGCN_CONSTEXPR(IS_RDNA4()) \ + if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_RDNA4()) \ { \ asm volatile(#asm_operator " %0, %1 th:TH_DEFAULT scope:SCOPE_DEV\n\t" \ "s_wait_storecnt_dscnt(%2)" \ : \ : "v"(ptr), "v"(temp_val), "I"(0x00)); \ } \ - else if ROCPRIM_AMDGCN_CONSTEXPR(IS_CDNA3()) \ + else if ROCPRIM_AMDGCN_CONSTEXPR(ROCPRIM_IS_CDNA3()) \ { \ asm volatile(#asm_operator " %0, %1 sc0 nt\n\t" \ "s_waitcnt(%2)" \ diff --git a/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp b/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp index e478d712c4f..62bfb36aa0a 100644 --- a/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp +++ b/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp @@ -94,16 +94,21 @@ class warp_reduce_dpp } #if !ROCPRIM_TARGET_SPIRV - static_assert(VirtualWaveSize <= 32, - "VirtualWaveSize > 32 is not supported without DPP broadcasts"); -#else - if constexpr(VirtualWaveSize > 32) + if constexpr(!ROCPRIM_IS_GENERIC()) { - ROCPRIM_PRINT_ERROR_ONCE( - "VirtualWaveSize > 32 is not supported without DPP broadcasts"); - return; + static_assert(VirtualWaveSize <= 32, + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); } + else #endif + { + if constexpr(VirtualWaveSize > 32) + { + ROCPRIM_PRINT_ERROR_ONCE( + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); + return; + } + } } else { diff --git a/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp b/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp index 58ccedc5f1a..3d2edf287b8 100644 --- a/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp +++ b/projects/rocprim/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp @@ -88,16 +88,21 @@ class warp_scan_dpp } #if !ROCPRIM_TARGET_SPIRV - static_assert(VirtualWaveSize <= 32, - "VirtualWaveSize > 32 is not supported without DPP broadcasts"); -#else - if constexpr(VirtualWaveSize > 32) + if constexpr(!ROCPRIM_IS_GENERIC()) { - ROCPRIM_PRINT_ERROR_ONCE( - "VirtualWaveSize > 32 is not supported without DPP broadcasts"); - return; + static_assert(VirtualWaveSize <= 32, + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); } + else #endif + { + if constexpr(VirtualWaveSize > 32) + { + ROCPRIM_PRINT_ERROR_ONCE( + "VirtualWaveSize > 32 is not supported without DPP broadcasts"); + return; + } + } } else { diff --git a/projects/rocprim/test/rocprim/test_config_dispatch.cpp b/projects/rocprim/test/rocprim/test_config_dispatch.cpp index 9a10e590800..7d39aaf5081 100644 --- a/projects/rocprim/test/rocprim/test_config_dispatch.cpp +++ b/projects/rocprim/test/rocprim/test_config_dispatch.cpp @@ -36,7 +36,14 @@ void write_target_arch([[maybe_unused]] target_arch host_arch, int* __restrict__ #if !defined(ROCPRIM_TARGET_SPIRV) static constexpr auto arch = rocprim::detail::device_target_arch(); - *result = arch == host_arch; + if constexpr(!ROCPRIM_IS_GENERIC()) + { + *result = arch == host_arch; + } + else + { + *result = -2; + } #else *result = -1; #endif @@ -94,8 +101,15 @@ TEST(RocprimConfigDispatchTests, HostMatchesDevice) if(result != -1) { - ASSERT_NE(host_arch, target_arch::invalid); - ASSERT_EQ(result, 1); + if(result != -2) + { + ASSERT_NE(host_arch, target_arch::invalid); + ASSERT_EQ(result, 1); + } + else + { + GTEST_SKIP() << "Generic build: result is null; skipping arch match assertion."; + } } else { From 79679b6ae739b8fe99efe54a3531dfb15dfeab19 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 11 Dec 2025 08:07:44 +0000 Subject: [PATCH 20/26] Manually fixing the worst regression after fixing predicate_flag --- .../device/detail/config/device_select_predicated_flag.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp index 63f1ecf26a4..cf346fd84d6 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/detail/config/device_select_predicated_flag.hpp @@ -1474,7 +1474,7 @@ constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< && (sizeof(flag_type) <= 2) && (sizeof(flag_type) > 1))) { return partition_config_params{ - {512, 18} + {256, 24} }; } // Based on data_type = rocprim::half, flag_type = int8_t @@ -1650,7 +1650,7 @@ constexpr auto select_predicated_flag_config_picker() -> std::enable_if_t< && (sizeof(flag_type) > 1))) { return partition_config_params{ - {512, 18} + {256, 18} }; } // Based on data_type = short, flag_type = int8_t From c9fcc81c1205bdc1f0b22f06248ff1b36fd96941 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 11 Dec 2025 08:28:41 +0000 Subject: [PATCH 21/26] Add more arch for configs --- .../rocprim/include/rocprim/device/config_types.hpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index ed3b78e359b..6e928c080e6 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -174,6 +174,7 @@ enum class target_arch : unsigned int gfx950 = 950, gfx1030 = 1030, gfx1100 = 1100, + gfx1101 = 1101, gfx1102 = 1102, gfx1152 = 1152, gfx1153 = 1153, @@ -200,6 +201,7 @@ enum class gen cdna4, rdna2, rdna3, + rdna3_5, rdna4, }; @@ -234,10 +236,14 @@ constexpr gen gen_from_target_arch(target_arch i) case target_arch::gfx950: return gen::cdna4; case target_arch::gfx1030: return gen::rdna2; case target_arch::gfx1100: + case target_arch::gfx1101: case target_arch::gfx1102: return gen::rdna3; + case target_arch::gfx1152: + case target_arch::gfx1153: return gen::rdna3_5; case target_arch::gfx1200: case target_arch::gfx1201: return gen::rdna4; - default: return gen::unknown; + case target_arch::unknown: + case target_arch::invalid: return gen::unknown; } } @@ -349,6 +355,7 @@ constexpr auto target_arch_descriptors = std::array{ X(gfx950), X(gfx1030), X(gfx1100), + X(gfx1101), X(gfx1102), X(gfx1152), X(gfx1153), @@ -395,6 +402,7 @@ constexpr arch::wavefront::target gen_wavefront_size(const gen gen) case gen::cdna4: return arch::wavefront::target::size64; case gen::rdna2: case gen::rdna3: + case gen::rdna3_5: case gen::rdna4: return arch::wavefront::target::size32; } } @@ -688,6 +696,8 @@ constexpr gpu get_target_gpu_from_name(std::string_view name) { for(const auto& each : target_gpu_names) { + // Look for a substring in the marketing name, e.g., + // "RX 7900" in "AMD Radeon RX 7900 XTX". if(name.find(std::get<0>(each)) != name.npos) { return std::get<1>(each); From f23bea4f87b2b0d606f837bb41dbfc2b375d19eb Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Tue, 16 Dec 2025 07:57:04 +0000 Subject: [PATCH 22/26] Add more supported architectures --- .../rocprim/include/rocprim/config.hpp | 1 + .../include/rocprim/device/config_types.hpp | 23 ++++++++-- .../apply_config_improvements.py | 19 +++++++- .../scripts/autotune/create_optimization.py | 44 ++++++++++++++++++- 4 files changed, 80 insertions(+), 7 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/config.hpp b/projects/rocprim/rocprim/include/rocprim/config.hpp index 2b9c5d40ddc..2a40cdd3544 100644 --- a/projects/rocprim/rocprim/include/rocprim/config.hpp +++ b/projects/rocprim/rocprim/include/rocprim/config.hpp @@ -169,6 +169,7 @@ #define ROCPRIM_IS_RDNA3() \ (__builtin_amdgcn_processor_is("gfx1100") || __builtin_amdgcn_processor_is("gfx1101") \ || __builtin_amdgcn_processor_is("gfx1102") || __builtin_amdgcn_processor_is("gfx1103") \ + || __builtin_amdgcn_processor_is("gfx1150") || __builtin_amdgcn_processor_is("gfx1151") \ || __builtin_amdgcn_processor_is("gfx1152") || __builtin_amdgcn_processor_is("gfx1153") \ || __builtin_amdgcn_processor_is("gfx11-generic")) #define ROCPRIM_IS_RDNA2() \ diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 6e928c080e6..bd710d920a6 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -172,10 +172,16 @@ enum class target_arch : unsigned int gfx90a = 910, gfx942 = 942, gfx950 = 950, + gfx1010 = 1010, + gfx1011 = 1011, + gfx1012 = 1012, gfx1030 = 1030, gfx1100 = 1100, gfx1101 = 1101, gfx1102 = 1102, + gfx1103 = 1103, + gfx1150 = 1150, + gfx1151 = 1151, gfx1152 = 1152, gfx1153 = 1153, gfx1200 = 1200, @@ -199,9 +205,9 @@ enum class gen cdna2, cdna3, cdna4, + rdna1, rdna2, rdna3, - rdna3_5, rdna4, }; @@ -234,12 +240,18 @@ constexpr gen gen_from_target_arch(target_arch i) case target_arch::gfx90a: return gen::cdna2; case target_arch::gfx942: return gen::cdna3; case target_arch::gfx950: return gen::cdna4; + case target_arch::gfx1010: + case target_arch::gfx1011: + case target_arch::gfx1012: return gen::rdna1; case target_arch::gfx1030: return gen::rdna2; case target_arch::gfx1100: case target_arch::gfx1101: - case target_arch::gfx1102: return gen::rdna3; + case target_arch::gfx1102: + case target_arch::gfx1103: + case target_arch::gfx1150: + case target_arch::gfx1151: case target_arch::gfx1152: - case target_arch::gfx1153: return gen::rdna3_5; + case target_arch::gfx1153: return gen::rdna3; case target_arch::gfx1200: case target_arch::gfx1201: return gen::rdna4; case target_arch::unknown: @@ -357,6 +369,9 @@ constexpr auto target_arch_descriptors = std::array{ X(gfx1100), X(gfx1101), X(gfx1102), + X(gfx1103), + X(gfx1150), + X(gfx1151), X(gfx1152), X(gfx1153), X(gfx1200), @@ -400,9 +415,9 @@ constexpr arch::wavefront::target gen_wavefront_size(const gen gen) case gen::cdna2: case gen::cdna3: case gen::cdna4: return arch::wavefront::target::size64; + case gen::rdna1: case gen::rdna2: case gen::rdna3: - case gen::rdna3_5: case gen::rdna4: return arch::wavefront::target::size32; } } diff --git a/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py b/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py index f60f86d9f5c..9c5884e300a 100644 --- a/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py +++ b/projects/rocprim/scripts/apply_config_improvements/apply_config_improvements.py @@ -68,12 +68,29 @@ def get_gen_from_architecture(arch): return "gen::cdna3" case "target_arch::gfx950": return "gen::cdna4" + case ( + "target_arch::gfx1010" + | "target_arch::gfx1011" + | "target_arch::gfx1012" + ): + return "gen::rdna1" case "target_arch::gfx1030": return "gen::rdna2" - case "target_arch::gfx1100" | "target_arch::gfx1102": + case ( + "target_arch::gfx1100" + | "target_arch::gfx1101" + | "target_arch::gfx1102" + | "target_arch::gfx1103" + | "target_arch::gfx1150" + | "target_arch::gfx1151" + | "target_arch::gfx1152" + | "target_arch::gfx1153" + ): return "gen::rdna3" case "target_arch::gfx1200" | "target_arch::gfx1201": return "gen::rdna4" + case "target_arch::unknown" | "target_arch::invalid": + return "gen::unknown" case _: return "gen::unknown" diff --git a/projects/rocprim/scripts/autotune/create_optimization.py b/projects/rocprim/scripts/autotune/create_optimization.py index 6d00a63e74a..b7ff6b07ff4 100755 --- a/projects/rocprim/scripts/autotune/create_optimization.py +++ b/projects/rocprim/scripts/autotune/create_optimization.py @@ -41,7 +41,30 @@ from typing import Dict, List, Callable, Optional, Tuple from jinja2 import Environment, PackageLoader, select_autoescape -TARGET_ARCHITECTURES = ['gfx803', 'gfx900', 'gfx906', 'gfx908', 'gfx90a', 'gfx942', 'gfx1030', 'gfx1100', 'gfx1102', 'gfx1201'] +TARGET_ARCHITECTURES = [ + "gfx803", + "gfx900", + "gfx906", + "gfx908", + "gfx90a", + "gfx942", + "gfx950", + "gfx1010", + "gfx1011", + "gfx1012", + "gfx1030", + "gfx1100", + "gfx1101", + "gfx1102", + "gfx1103", + "gfx1150", + "gfx1151", + "gfx1152", + "gfx1153", + "gfx1200", + "gfx1201", +] + TARGET_GPUS_DICT = { "MI350X": "mi350x", "MI325X": "mi325x", @@ -896,12 +919,29 @@ def __get_gen_from_architecture(self, arch): return "gen::cdna3" case "target_arch::gfx950": return "gen::cdna4" + case ( + "target_arch::gfx1010" + | "target_arch::gfx1011" + | "target_arch::gfx1012" + ): + return "gen::rdna1" case "target_arch::gfx1030": return "gen::rdna2" - case "target_arch::gfx1100" | "target_arch::gfx1102": + case ( + "target_arch::gfx1100" + | "target_arch::gfx1101" + | "target_arch::gfx1102" + | "target_arch::gfx1103" + | "target_arch::gfx1150" + | "target_arch::gfx1151" + | "target_arch::gfx1152" + | "target_arch::gfx1153" + ): return "gen::rdna3" case "target_arch::gfx1200" | "target_arch::gfx1201": return "gen::rdna4" + case "target_arch::unknown" | "target_arch::invalid": + return "gen::unknown" case _: return "gen::unknown" From 6dd20f0f0bd0179449e4724d32f818cf35e4beea Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Tue, 16 Dec 2025 10:17:49 +0000 Subject: [PATCH 23/26] Scope the define to rocprim --- .../include/rocprim/device/config_types.hpp | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index bd710d920a6..b8ee3be8995 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -356,28 +356,32 @@ struct target_arch_descriptor const char *arch_name; }; -#define X(ID) target_arch_descriptor{target_arch::ID, #ID} +#define ROCPRIM_DEF_ARCH(ID) \ + target_arch_descriptor \ + { \ + target_arch::ID, #ID \ + } constexpr auto target_arch_descriptors = std::array{ - X(gfx803), - X(gfx900), - X(gfx906), - X(gfx908), - X(gfx90a), - X(gfx942), - X(gfx950), - X(gfx1030), - X(gfx1100), - X(gfx1101), - X(gfx1102), - X(gfx1103), - X(gfx1150), - X(gfx1151), - X(gfx1152), - X(gfx1153), - X(gfx1200), - X(gfx1201), + ROCPRIM_DEF_ARCH(gfx803), + ROCPRIM_DEF_ARCH(gfx900), + ROCPRIM_DEF_ARCH(gfx906), + ROCPRIM_DEF_ARCH(gfx908), + ROCPRIM_DEF_ARCH(gfx90a), + ROCPRIM_DEF_ARCH(gfx942), + ROCPRIM_DEF_ARCH(gfx950), + ROCPRIM_DEF_ARCH(gfx1030), + ROCPRIM_DEF_ARCH(gfx1100), + ROCPRIM_DEF_ARCH(gfx1101), + ROCPRIM_DEF_ARCH(gfx1102), + ROCPRIM_DEF_ARCH(gfx1103), + ROCPRIM_DEF_ARCH(gfx1150), + ROCPRIM_DEF_ARCH(gfx1151), + ROCPRIM_DEF_ARCH(gfx1152), + ROCPRIM_DEF_ARCH(gfx1153), + ROCPRIM_DEF_ARCH(gfx1200), + ROCPRIM_DEF_ARCH(gfx1201), }; -#undef X +#undef ROCPRIM_DEF_ARCH constexpr target_arch get_target_arch_from_name(const char* const arch_name, const std::size_t n) { From f4be2317ce63fab95ae5454e43172eb52dfc0e4c Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Thu, 18 Dec 2025 14:49:16 +0000 Subject: [PATCH 24/26] Add temp fix for failing test --- .../include/rocprim/device/config_types.hpp | 38 ++++--------------- .../rocprim/device/device_histogram.hpp | 9 ++++- 2 files changed, 14 insertions(+), 33 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index b8ee3be8995..65d938f5a1b 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -362,24 +362,13 @@ struct target_arch_descriptor target_arch::ID, #ID \ } constexpr auto target_arch_descriptors = std::array{ - ROCPRIM_DEF_ARCH(gfx803), - ROCPRIM_DEF_ARCH(gfx900), - ROCPRIM_DEF_ARCH(gfx906), - ROCPRIM_DEF_ARCH(gfx908), - ROCPRIM_DEF_ARCH(gfx90a), - ROCPRIM_DEF_ARCH(gfx942), - ROCPRIM_DEF_ARCH(gfx950), - ROCPRIM_DEF_ARCH(gfx1030), - ROCPRIM_DEF_ARCH(gfx1100), - ROCPRIM_DEF_ARCH(gfx1101), - ROCPRIM_DEF_ARCH(gfx1102), - ROCPRIM_DEF_ARCH(gfx1103), - ROCPRIM_DEF_ARCH(gfx1150), - ROCPRIM_DEF_ARCH(gfx1151), - ROCPRIM_DEF_ARCH(gfx1152), - ROCPRIM_DEF_ARCH(gfx1153), - ROCPRIM_DEF_ARCH(gfx1200), - ROCPRIM_DEF_ARCH(gfx1201), + ROCPRIM_DEF_ARCH(gfx803), ROCPRIM_DEF_ARCH(gfx900), ROCPRIM_DEF_ARCH(gfx906), + ROCPRIM_DEF_ARCH(gfx908), ROCPRIM_DEF_ARCH(gfx90a), ROCPRIM_DEF_ARCH(gfx942), + ROCPRIM_DEF_ARCH(gfx950), ROCPRIM_DEF_ARCH(gfx1010), ROCPRIM_DEF_ARCH(gfx1011), + ROCPRIM_DEF_ARCH(gfx1012), ROCPRIM_DEF_ARCH(gfx1030), ROCPRIM_DEF_ARCH(gfx1100), + ROCPRIM_DEF_ARCH(gfx1101), ROCPRIM_DEF_ARCH(gfx1102), ROCPRIM_DEF_ARCH(gfx1103), + ROCPRIM_DEF_ARCH(gfx1150), ROCPRIM_DEF_ARCH(gfx1151), ROCPRIM_DEF_ARCH(gfx1152), + ROCPRIM_DEF_ARCH(gfx1153), ROCPRIM_DEF_ARCH(gfx1200), ROCPRIM_DEF_ARCH(gfx1201), }; #undef ROCPRIM_DEF_ARCH @@ -395,19 +384,6 @@ constexpr target_arch get_target_arch_from_name(const char* const arch_name, con return target_arch::unknown; } -template -constexpr void for_each_arch_impl(F&& f, std::index_sequence) -{ - (f(std::integral_constant{}), ...); -} - -template -constexpr void for_each_arch(F&& f) -{ - for_each_arch_impl(std::forward(f), - std::make_index_sequence{}); -} - constexpr arch::wavefront::target gen_wavefront_size(const gen gen) { switch(gen) diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp index b4ac6e6b4ed..5dd1f5343c6 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp @@ -184,8 +184,13 @@ inline hipError_t histogram_impl(void* temporary_storage, max_bins = std::max(max_bins, bins[channel]); } - const bool use_shared_mem = total_shared_bins <= shared_impl_max_bins; - const bool use_private_histogram = target_arch == target_arch::gfx942; + hipStreamCaptureStatus status; + ROCPRIM_RETURN_ON_ERROR(hipStreamIsCapturing(stream, &status)); + + const bool use_shared_mem = total_shared_bins <= shared_impl_max_bins; + // TEMP FIX: disable optimization when using hipgraphs. + const bool use_private_histogram + = target_arch == target_arch::gfx942 && hipStreamCaptureStatusActive != status; Counter* private_histograms = nullptr; unsigned int* block_id_count = nullptr; From 405ded6b80bdcb0b76bd29d8eec485303cb9e336 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Fri, 19 Dec 2025 11:30:08 +0000 Subject: [PATCH 25/26] TEMP FIX: instead of disabling optimization for array size one larger --- .../include/rocprim/device/config_types.hpp | 19 +++++++++++-------- .../rocprim/device/device_histogram.hpp | 9 ++------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 65d938f5a1b..35c5ea55d9c 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -361,14 +361,17 @@ struct target_arch_descriptor { \ target_arch::ID, #ID \ } -constexpr auto target_arch_descriptors = std::array{ - ROCPRIM_DEF_ARCH(gfx803), ROCPRIM_DEF_ARCH(gfx900), ROCPRIM_DEF_ARCH(gfx906), - ROCPRIM_DEF_ARCH(gfx908), ROCPRIM_DEF_ARCH(gfx90a), ROCPRIM_DEF_ARCH(gfx942), - ROCPRIM_DEF_ARCH(gfx950), ROCPRIM_DEF_ARCH(gfx1010), ROCPRIM_DEF_ARCH(gfx1011), - ROCPRIM_DEF_ARCH(gfx1012), ROCPRIM_DEF_ARCH(gfx1030), ROCPRIM_DEF_ARCH(gfx1100), - ROCPRIM_DEF_ARCH(gfx1101), ROCPRIM_DEF_ARCH(gfx1102), ROCPRIM_DEF_ARCH(gfx1103), - ROCPRIM_DEF_ARCH(gfx1150), ROCPRIM_DEF_ARCH(gfx1151), ROCPRIM_DEF_ARCH(gfx1152), - ROCPRIM_DEF_ARCH(gfx1153), ROCPRIM_DEF_ARCH(gfx1200), ROCPRIM_DEF_ARCH(gfx1201), +// TEMP FIX: The size of the array should be 1 larger then the amount of items. +constexpr std::array target_arch_descriptors = { + { + ROCPRIM_DEF_ARCH(gfx803), ROCPRIM_DEF_ARCH(gfx900), ROCPRIM_DEF_ARCH(gfx906), + ROCPRIM_DEF_ARCH(gfx908), ROCPRIM_DEF_ARCH(gfx90a), ROCPRIM_DEF_ARCH(gfx942), + ROCPRIM_DEF_ARCH(gfx950), ROCPRIM_DEF_ARCH(gfx1010), ROCPRIM_DEF_ARCH(gfx1011), + ROCPRIM_DEF_ARCH(gfx1012), ROCPRIM_DEF_ARCH(gfx1030), ROCPRIM_DEF_ARCH(gfx1100), + ROCPRIM_DEF_ARCH(gfx1101), ROCPRIM_DEF_ARCH(gfx1102), ROCPRIM_DEF_ARCH(gfx1103), + ROCPRIM_DEF_ARCH(gfx1150), ROCPRIM_DEF_ARCH(gfx1151), ROCPRIM_DEF_ARCH(gfx1152), + ROCPRIM_DEF_ARCH(gfx1153), ROCPRIM_DEF_ARCH(gfx1200), ROCPRIM_DEF_ARCH(gfx1201), + } }; #undef ROCPRIM_DEF_ARCH diff --git a/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp b/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp index 5dd1f5343c6..b4ac6e6b4ed 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/device_histogram.hpp @@ -184,13 +184,8 @@ inline hipError_t histogram_impl(void* temporary_storage, max_bins = std::max(max_bins, bins[channel]); } - hipStreamCaptureStatus status; - ROCPRIM_RETURN_ON_ERROR(hipStreamIsCapturing(stream, &status)); - - const bool use_shared_mem = total_shared_bins <= shared_impl_max_bins; - // TEMP FIX: disable optimization when using hipgraphs. - const bool use_private_histogram - = target_arch == target_arch::gfx942 && hipStreamCaptureStatusActive != status; + const bool use_shared_mem = total_shared_bins <= shared_impl_max_bins; + const bool use_private_histogram = target_arch == target_arch::gfx942; Counter* private_histograms = nullptr; unsigned int* block_id_count = nullptr; From 984c824839bf1fedfe220b6b1b29bfe817c83d94 Mon Sep 17 00:00:00 2001 From: Nick Breed Date: Mon, 5 Jan 2026 12:36:05 +0000 Subject: [PATCH 26/26] Replace workaround with less undefined fix --- .../include/rocprim/device/config_types.hpp | 59 +++++++++---------- 1 file changed, 28 insertions(+), 31 deletions(-) diff --git a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp index 35c5ea55d9c..34e66f5f26b 100644 --- a/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp +++ b/projects/rocprim/rocprim/include/rocprim/device/config_types.hpp @@ -161,6 +161,7 @@ using default_or_custom_config = typename std::conditional::value, Default, Config>::type; #ifndef DOXYGEN_SHOULD_SKIP_THIS +// NOTE: When adding a new target_arch also add it to gen_from_target_arch and get_target_arch_from_name enum class target_arch : unsigned int { // This must be zero, to initialize the device -> architecture cache @@ -350,42 +351,38 @@ constexpr bool prefix_equals(const char* lhs, const char* rhs, std::size_t n) return i == n && *lhs == '\0'; } -struct target_arch_descriptor -{ - target_arch arch; - const char *arch_name; -}; - -#define ROCPRIM_DEF_ARCH(ID) \ - target_arch_descriptor \ - { \ - target_arch::ID, #ID \ +#define ROCPRIM_RETURN_IF_ARCH(ID) \ + if(prefix_equals(#ID, arch_name, n)) \ + { \ + return target_arch::ID; \ } -// TEMP FIX: The size of the array should be 1 larger then the amount of items. -constexpr std::array target_arch_descriptors = { - { - ROCPRIM_DEF_ARCH(gfx803), ROCPRIM_DEF_ARCH(gfx900), ROCPRIM_DEF_ARCH(gfx906), - ROCPRIM_DEF_ARCH(gfx908), ROCPRIM_DEF_ARCH(gfx90a), ROCPRIM_DEF_ARCH(gfx942), - ROCPRIM_DEF_ARCH(gfx950), ROCPRIM_DEF_ARCH(gfx1010), ROCPRIM_DEF_ARCH(gfx1011), - ROCPRIM_DEF_ARCH(gfx1012), ROCPRIM_DEF_ARCH(gfx1030), ROCPRIM_DEF_ARCH(gfx1100), - ROCPRIM_DEF_ARCH(gfx1101), ROCPRIM_DEF_ARCH(gfx1102), ROCPRIM_DEF_ARCH(gfx1103), - ROCPRIM_DEF_ARCH(gfx1150), ROCPRIM_DEF_ARCH(gfx1151), ROCPRIM_DEF_ARCH(gfx1152), - ROCPRIM_DEF_ARCH(gfx1153), ROCPRIM_DEF_ARCH(gfx1200), ROCPRIM_DEF_ARCH(gfx1201), - } -}; -#undef ROCPRIM_DEF_ARCH - constexpr target_arch get_target_arch_from_name(const char* const arch_name, const std::size_t n) { - for (const auto& desc : target_arch_descriptors) - { - if(prefix_equals(desc.arch_name, arch_name, n)) - { - return desc.arch; - } - } + ROCPRIM_RETURN_IF_ARCH(gfx803); + ROCPRIM_RETURN_IF_ARCH(gfx900); + ROCPRIM_RETURN_IF_ARCH(gfx906); + ROCPRIM_RETURN_IF_ARCH(gfx908); + ROCPRIM_RETURN_IF_ARCH(gfx90a); + ROCPRIM_RETURN_IF_ARCH(gfx942); + ROCPRIM_RETURN_IF_ARCH(gfx950); + ROCPRIM_RETURN_IF_ARCH(gfx1010); + ROCPRIM_RETURN_IF_ARCH(gfx1011); + ROCPRIM_RETURN_IF_ARCH(gfx1012); + ROCPRIM_RETURN_IF_ARCH(gfx1030); + ROCPRIM_RETURN_IF_ARCH(gfx1100); + ROCPRIM_RETURN_IF_ARCH(gfx1101); + ROCPRIM_RETURN_IF_ARCH(gfx1102); + ROCPRIM_RETURN_IF_ARCH(gfx1103); + ROCPRIM_RETURN_IF_ARCH(gfx1150); + ROCPRIM_RETURN_IF_ARCH(gfx1151); + ROCPRIM_RETURN_IF_ARCH(gfx1152); + ROCPRIM_RETURN_IF_ARCH(gfx1153); + ROCPRIM_RETURN_IF_ARCH(gfx1200); + ROCPRIM_RETURN_IF_ARCH(gfx1201); + return target_arch::unknown; } +#undef ROCPRIM_RETURN_IF_ARCH constexpr arch::wavefront::target gen_wavefront_size(const gen gen) {