From 078ec9def4fb89b41a340e5e08c5cbadcd4620da Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 20 Oct 2025 14:12:57 +0000 Subject: [PATCH 1/5] Add sigmoid option to topk_softmax --- .../ck_tile/09_topk_softmax/topk_softmax.cpp | 45 +++++++- .../09_topk_softmax/topk_softmax_api.cpp | 104 +++++++++++++++--- .../09_topk_softmax/topk_softmax_api.hpp | 1 + .../kernel/topk_softmax_kernel.hpp | 4 +- .../topk_softmax_warp_per_row_pipeline.hpp | 20 +++- .../topk_softmax_warp_per_row_problem.hpp | 3 + .../topk_softmax/test_topk_softmax.hpp | 51 ++++++++- .../topk_softmax/test_topk_softmax_api.cpp | 104 +++++++++++++++--- .../topk_softmax/test_topk_softmax_api.hpp | 1 + 9 files changed, 282 insertions(+), 51 deletions(-) diff --git a/example/ck_tile/09_topk_softmax/topk_softmax.cpp b/example/ck_tile/09_topk_softmax/topk_softmax.cpp index 0487bd05d23..400329986a4 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax.cpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax.cpp @@ -83,6 +83,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor& x, reference_topk(y, y_values, y_indices, k, dim, largest, sorted); } +template +auto reference_topk_sigmoid(const ck_tile::HostTensor& x, + ck_tile::HostTensor& y_values, + ck_tile::HostTensor& y_indices, + ck_tile::index_t k, + ck_tile::index_t dim = -1, + bool largest = true, + bool sorted = true) +{ + using namespace ck_tile; + + // topk only - no need to apply the sigmoid first + auto x_fp32 = x.template CopyAsType(); + reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted); + // apply sigmoid + std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) { + return WeightType(1) / (WeightType(1) + exp(-value)); + }); +} + // different threshold for different dtype template auto get_elimit(std::string /*init_method*/) @@ -133,7 +153,8 @@ auto create_args(int argc, char* argv[]) .insert("warmup", "5", "number of iterations before benchmark the kernel") .insert("repeat", "20", "number of iterations to benchmark the kernel") .insert("json", "0", "0: No Json, 1: Dump Results in Json format") - .insert("jsonfile", "topk_softmax.json", "json file name to dump results"); + .insert("jsonfile", "topk_softmax.json", "json file name to dump results") + .insert("activation", "softmax", "activation function to use: softmax or sigmoid"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -154,6 +175,7 @@ bool test_topk_softmax(ck_tile::ArgParser args) int kname = args.get_int("kname"); int warmup = args.get_int("warmup"); int repeat = args.get_int("repeat"); + std::string activation = args.get_str("activation"); if(stride_input < 0) { @@ -204,7 +226,7 @@ bool test_topk_softmax(ck_tile::ArgParser args) x_dev.ToDevice(x_host.data()); - topk_softmax_trait trait{input_prec, weight_prec, experts}; + topk_softmax_trait trait{input_prec, weight_prec, experts, activation}; topk_softmax_kargs karg{x_dev.GetDeviceBuffer(), value_dev.GetDeviceBuffer(), @@ -221,7 +243,7 @@ bool test_topk_softmax(ck_tile::ArgParser args) warmup, repeat}; auto ms = topk_softmax(trait, karg, sc); - printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ", + printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ", input_prec.c_str(), weight_prec.c_str(), tokens, @@ -229,6 +251,7 @@ bool test_topk_softmax(ck_tile::ArgParser args) topk, stride_input, stride_output, + activation.c_str(), ms); if(ms < 0) printf("not supported\n"); @@ -247,8 +270,20 @@ bool test_topk_softmax(ck_tile::ArgParser args) ck_tile::HostTensor value_ref({tokens, topk}, {stride_output, 1}); ck_tile::HostTensor index_ref({tokens, topk}, {stride_output, 1}); - reference_topk_softmax( - x_host, value_ref, index_ref, topk); + if(activation == "softmax") + { + reference_topk_softmax( + x_host, value_ref, index_ref, topk); + } + else if(activation == "sigmoid") + { + reference_topk_sigmoid( + x_host, value_ref, index_ref, topk); + } + else + { + throw std::runtime_error("unsupported activation type: " + activation); + } auto [rtol, atol] = get_elimit(""); for(int i_t = 0; i_t < tokens; i_t++) diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp index 6e6bb20020c..bcc04916e6b 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp @@ -3,10 +3,11 @@ #include "topk_softmax_api.hpp" -#define TOPK_SOFTMAX_DISPATCH(experts_) \ +#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \ constexpr ck_tile::index_t ts_experts = experts_; \ + constexpr bool ts_use_softmax = use_softmax_; \ using ts_problem = ck_tile:: \ - TopkSoftmaxWarpPerRowProblem; \ + TopkSoftmaxWarpPerRowProblem; \ using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline; \ \ using kernel = ck_tile::TopkSoftmaxKernel; \ @@ -23,7 +24,7 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s) { - if(t.input_type == "fp16" && t.weight_type == "fp32") + if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax") { using ts_input_type = ck_tile::fp16_t; using ts_weight_type = float; @@ -31,36 +32,36 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c #if 1 if(t.experts <= 8) { - TOPK_SOFTMAX_DISPATCH(8) + TOPK_SOFTMAX_DISPATCH(8, true) } else if(t.experts <= 16) { - TOPK_SOFTMAX_DISPATCH(16) + TOPK_SOFTMAX_DISPATCH(16, true) } else if(t.experts <= 32) { - TOPK_SOFTMAX_DISPATCH(32) + TOPK_SOFTMAX_DISPATCH(32, true) } else if(t.experts <= 64) { - TOPK_SOFTMAX_DISPATCH(64) + TOPK_SOFTMAX_DISPATCH(64, true) } else if(t.experts <= 128) { - TOPK_SOFTMAX_DISPATCH(128) + TOPK_SOFTMAX_DISPATCH(128, true) } else if(t.experts <= 192) { - TOPK_SOFTMAX_DISPATCH(192) + TOPK_SOFTMAX_DISPATCH(192, true) } #else if(t.experts <= 128) { - TOPK_SOFTMAX_DISPATCH(128) + TOPK_SOFTMAX_DISPATCH(128, true) } #endif } - else if(t.input_type == "bf16" && t.weight_type == "fp32") + else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax") { #if 1 using ts_input_type = ck_tile::bf16_t; @@ -68,27 +69,96 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c using ts_index_type = ck_tile::index_t; if(t.experts <= 8) { - TOPK_SOFTMAX_DISPATCH(8) + TOPK_SOFTMAX_DISPATCH(8, true) } else if(t.experts <= 16) { - TOPK_SOFTMAX_DISPATCH(16) + TOPK_SOFTMAX_DISPATCH(16, true) } else if(t.experts <= 32) { - TOPK_SOFTMAX_DISPATCH(32) + TOPK_SOFTMAX_DISPATCH(32, true) } else if(t.experts <= 64) { - TOPK_SOFTMAX_DISPATCH(64) + TOPK_SOFTMAX_DISPATCH(64, true) } else if(t.experts <= 128) { - TOPK_SOFTMAX_DISPATCH(128) + TOPK_SOFTMAX_DISPATCH(128, true) } else if(t.experts <= 192) { - TOPK_SOFTMAX_DISPATCH(192) + TOPK_SOFTMAX_DISPATCH(192, true) + } +#endif + } + if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid") + { + using ts_input_type = ck_tile::fp16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; +#if 1 + if(t.experts <= 8) + { + TOPK_SOFTMAX_DISPATCH(8, false) + } + else if(t.experts <= 16) + { + TOPK_SOFTMAX_DISPATCH(16, false) + } + else if(t.experts <= 32) + { + TOPK_SOFTMAX_DISPATCH(32, false) + } + else if(t.experts <= 64) + { + TOPK_SOFTMAX_DISPATCH(64, false) + } + else if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128, false) + } + else if(t.experts <= 192) + { + TOPK_SOFTMAX_DISPATCH(192, false) + } +#else + if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128, false) + } +#endif + } + else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid") + { +#if 1 + using ts_input_type = ck_tile::bf16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; + if(t.experts <= 8) + { + TOPK_SOFTMAX_DISPATCH(8, false) + } + else if(t.experts <= 16) + { + TOPK_SOFTMAX_DISPATCH(16, false) + } + else if(t.experts <= 32) + { + TOPK_SOFTMAX_DISPATCH(32, false) + } + else if(t.experts <= 64) + { + TOPK_SOFTMAX_DISPATCH(64, false) + } + else if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128, false) + } + else if(t.experts <= 192) + { + TOPK_SOFTMAX_DISPATCH(192, false) } #endif } diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp index 65651efa4d4..c98a887736f 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.hpp @@ -12,6 +12,7 @@ struct topk_softmax_trait std::string input_type; std::string weight_type; // currently always float int experts; + std::string activation; // "softmax" or "sigmoid" }; struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs diff --git a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp index e8727ea0659..019e940a339 100644 --- a/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp +++ b/include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp @@ -21,7 +21,7 @@ struct TopkSoftmaxHostArgs index_t num_experts; index_t topk; index_t stride_input; // row stride for input, at least experts - index_t stride_output; // row stride for output/indices, at least tpok + index_t stride_output; // row stride for output/indices, at least topk }; template @@ -45,7 +45,7 @@ struct TopkSoftmaxKernel index_t num_experts; index_t topk; index_t stride_input; // row stride for input, at least experts - index_t stride_output; // row stride for output/indices, at least tpok + index_t stride_output; // row stride for output/indices, at least topk }; using Kargs = TopkSoftmaxKargs; diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp index d620d9bec9c..19fdbf46507 100644 --- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp +++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp @@ -90,6 +90,11 @@ struct TopkSoftmaxWarpPerRowPipeline const auto current_expert = x_indices.at(number<1>{}); w_(idx) = current_expert >= experts ? -numeric::infinity() : w_(idx); + if constexpr(!Problem::ActivationIsSoftmax) + { + // sigmoid can be pre-computed already here if not using softmax + w_(idx) = WeightType(1) / (WeightType(1) + exp(-w_(idx))); + } }; tile_sweeper ts{w_, w_f}; ts(); @@ -97,10 +102,17 @@ struct TopkSoftmaxWarpPerRowPipeline #endif }(); - // softmax - auto y = softmax(w); - - topk(y, out_win, idx_win, k); + if constexpr(Problem::ActivationIsSoftmax) + { + auto y = softmax(w); + topk(y, out_win, idx_win, k); + } + else + { + // sigmoid was already pre-computed above, so only do topk now + topk(w, out_win, idx_win, k); + } + // check exit if constexpr(Problem::LaunchType == 0) diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp index 917096ad5e3..3c5f6b328b4 100644 --- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp +++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp @@ -13,6 +13,7 @@ template 0, persistent #occupancy @@ -31,6 +32,8 @@ struct TopkSoftmaxWarpPerRowProblem static constexpr index_t BlockSize = BlockSize_; static constexpr index_t WarpSize = get_warp_size(); + static constexpr bool ActivationIsSoftmax = ActivationIsSoftmax_; + static_assert(BytesPerIssue % sizeof(InputType) == 0); static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType); static_assert(Experts % VectorSize == 0); diff --git a/test/ck_tile/topk_softmax/test_topk_softmax.hpp b/test/ck_tile/topk_softmax/test_topk_softmax.hpp index 1bb400ad07a..cc64aec91ce 100644 --- a/test/ck_tile/topk_softmax/test_topk_softmax.hpp +++ b/test/ck_tile/topk_softmax/test_topk_softmax.hpp @@ -39,6 +39,26 @@ auto reference_topk_softmax(const ck_tile::HostTensor& x, reference_topk(y, y_values, y_indices, k, dim, largest, sorted); } +template +auto reference_topk_sigmoid(const ck_tile::HostTensor& x, + ck_tile::HostTensor& y_values, + ck_tile::HostTensor& y_indices, + ck_tile::index_t k, + ck_tile::index_t dim = -1, + bool largest = true, + bool sorted = true) +{ + using namespace ck_tile; + + // topk only - no need to apply the sigmoid first + auto x_fp32 = x.template CopyAsType(); + reference_topk(x_fp32, y_values, y_indices, k, dim, largest, sorted); + // apply sigmoid + std::transform(y_values.begin(), y_values.end(), y_values.begin(), [](auto value) { + return WeightType(1) / (WeightType(1) + exp(-value)); + }); +} + // different threshold for different dtype template auto get_elimit(std::string /*init_method*/) @@ -87,7 +107,8 @@ auto create_args(int argc, char* argv[]) .insert("seed", "-1", "seed to be used, -1 means random every time") .insert("kname", "0", "when set to 1 it will print kernel name") .insert("warmup", "5", "number of iterations before benchmark the kernel") - .insert("repeat", "20", "number of iterations to benchmark the kernel"); + .insert("repeat", "20", "number of iterations to benchmark the kernel") + .insert("activation", "softmax", "activation function to use: softmax or sigmoid"); bool result = arg_parser.parse(argc, argv); return std::make_tuple(result, arg_parser); @@ -108,6 +129,7 @@ bool test_topk_softmax(ck_tile::ArgParser args) int kname = args.get_int("kname"); int warmup = args.get_int("warmup"); int repeat = args.get_int("repeat"); + std::string activation = args.get_str("activation"); if(stride_input < 0) { @@ -158,7 +180,7 @@ bool test_topk_softmax(ck_tile::ArgParser args) x_dev.ToDevice(x_host.data()); - topk_softmax_trait trait{input_prec, weight_prec, experts}; + topk_softmax_trait trait{input_prec, weight_prec, experts, activation}; topk_softmax_kargs karg{x_dev.GetDeviceBuffer(), value_dev.GetDeviceBuffer(), @@ -175,7 +197,7 @@ bool test_topk_softmax(ck_tile::ArgParser args) warmup, repeat}; auto ms = topk_softmax(trait, karg, sc); - printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ", + printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, activation:%s, ms:%f, ", input_prec.c_str(), weight_prec.c_str(), tokens, @@ -183,6 +205,7 @@ bool test_topk_softmax(ck_tile::ArgParser args) topk, stride_input, stride_output, + activation.c_str(), ms); if(ms < 0) printf("not supported\n"); @@ -201,8 +224,20 @@ bool test_topk_softmax(ck_tile::ArgParser args) ck_tile::HostTensor value_ref({tokens, topk}, {stride_output, 1}); ck_tile::HostTensor index_ref({tokens, topk}, {stride_output, 1}); - reference_topk_softmax( - x_host, value_ref, index_ref, topk); + if(activation == "softmax") + { + reference_topk_softmax( + x_host, value_ref, index_ref, topk); + } + else if(activation == "sigmoid") + { + reference_topk_sigmoid( + x_host, value_ref, index_ref, topk); + } + else + { + throw std::runtime_error("unsupported activation type: " + activation); + } auto [rtol, atol] = get_elimit(""); for(int i_t = 0; i_t < tokens; i_t++) @@ -255,7 +290,11 @@ int run_gemm_combinations(std::string const& data_type) {"-t=71", "-e=11", "-k=11", "-st_i=30", "-st_o=12"}, {"-t=1", "-e=1", "-k=1"}, {"-t=99", "-e=2", "-k=1", "-st_i=11", "-st_o=5"}, - {"-t=333", "-e=99", "-k=13", "-st_i=191", "-st_o=17"}}; + {"-t=333", "-e=99", "-k=13", "-st_i=191", "-st_o=17"}, + {"-t=20", "-e=5", "-k=2", "-activation=sigmoid"}, + {"-t=220", "-e=9", "-k=3", "-activation=sigmoid"}, + {"-t=500", "-e=21", "-k=13", "-activation=sigmoid"} + }; bool result = true; std::string pr_i = "-pr_i=" + data_type; diff --git a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp index 7c90c8200c6..303f7cb10d8 100644 --- a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp +++ b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp @@ -3,10 +3,11 @@ #include "test_topk_softmax_api.hpp" -#define TOPK_SOFTMAX_DISPATCH(experts_) \ +#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \ constexpr ck_tile::index_t ts_experts = experts_; \ + constexpr bool ts_use_softmax = use_softmax_; \ using ts_problem = ck_tile:: \ - TopkSoftmaxWarpPerRowProblem; \ + TopkSoftmaxWarpPerRowProblem; \ using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline; \ \ using kernel = ck_tile::TopkSoftmaxKernel; \ @@ -23,7 +24,7 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s) { - if(t.input_type == "fp16" && t.weight_type == "fp32") + if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax") { using ts_input_type = ck_tile::fp16_t; using ts_weight_type = float; @@ -31,36 +32,36 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c #if 1 if(t.experts <= 8) { - TOPK_SOFTMAX_DISPATCH(8) + TOPK_SOFTMAX_DISPATCH(8, true) } else if(t.experts <= 16) { - TOPK_SOFTMAX_DISPATCH(16) + TOPK_SOFTMAX_DISPATCH(16, true) } else if(t.experts <= 32) { - TOPK_SOFTMAX_DISPATCH(32) + TOPK_SOFTMAX_DISPATCH(32, true) } else if(t.experts <= 64) { - TOPK_SOFTMAX_DISPATCH(64) + TOPK_SOFTMAX_DISPATCH(64, true) } else if(t.experts <= 128) { - TOPK_SOFTMAX_DISPATCH(128) + TOPK_SOFTMAX_DISPATCH(128, true) } else if(t.experts <= 192) { - TOPK_SOFTMAX_DISPATCH(192) + TOPK_SOFTMAX_DISPATCH(192, true) } #else if(t.experts <= 128) { - TOPK_SOFTMAX_DISPATCH(128) + TOPK_SOFTMAX_DISPATCH(128, true) } #endif } - else if(t.input_type == "bf16" && t.weight_type == "fp32") + else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax") { #if 1 using ts_input_type = ck_tile::bf16_t; @@ -68,27 +69,96 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c using ts_index_type = ck_tile::index_t; if(t.experts <= 8) { - TOPK_SOFTMAX_DISPATCH(8) + TOPK_SOFTMAX_DISPATCH(8, true) } else if(t.experts <= 16) { - TOPK_SOFTMAX_DISPATCH(16) + TOPK_SOFTMAX_DISPATCH(16, true) } else if(t.experts <= 32) { - TOPK_SOFTMAX_DISPATCH(32) + TOPK_SOFTMAX_DISPATCH(32, true) } else if(t.experts <= 64) { - TOPK_SOFTMAX_DISPATCH(64) + TOPK_SOFTMAX_DISPATCH(64, true) } else if(t.experts <= 128) { - TOPK_SOFTMAX_DISPATCH(128) + TOPK_SOFTMAX_DISPATCH(128, true) } else if(t.experts <= 192) { - TOPK_SOFTMAX_DISPATCH(192) + TOPK_SOFTMAX_DISPATCH(192, true) + } +#endif + } + if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid") + { + using ts_input_type = ck_tile::fp16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; +#if 1 + if(t.experts <= 8) + { + TOPK_SOFTMAX_DISPATCH(8, false) + } + else if(t.experts <= 16) + { + TOPK_SOFTMAX_DISPATCH(16, false) + } + else if(t.experts <= 32) + { + TOPK_SOFTMAX_DISPATCH(32, false) + } + else if(t.experts <= 64) + { + TOPK_SOFTMAX_DISPATCH(64, false) + } + else if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128, false) + } + else if(t.experts <= 192) + { + TOPK_SOFTMAX_DISPATCH(192, false) + } +#else + if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128, false) + } +#endif + } + else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid") + { +#if 1 + using ts_input_type = ck_tile::bf16_t; + using ts_weight_type = float; + using ts_index_type = ck_tile::index_t; + if(t.experts <= 8) + { + TOPK_SOFTMAX_DISPATCH(8, false) + } + else if(t.experts <= 16) + { + TOPK_SOFTMAX_DISPATCH(16, false) + } + else if(t.experts <= 32) + { + TOPK_SOFTMAX_DISPATCH(32, false) + } + else if(t.experts <= 64) + { + TOPK_SOFTMAX_DISPATCH(64, false) + } + else if(t.experts <= 128) + { + TOPK_SOFTMAX_DISPATCH(128, false) + } + else if(t.experts <= 192) + { + TOPK_SOFTMAX_DISPATCH(192, false) } #endif } diff --git a/test/ck_tile/topk_softmax/test_topk_softmax_api.hpp b/test/ck_tile/topk_softmax/test_topk_softmax_api.hpp index 65651efa4d4..c98a887736f 100644 --- a/test/ck_tile/topk_softmax/test_topk_softmax_api.hpp +++ b/test/ck_tile/topk_softmax/test_topk_softmax_api.hpp @@ -12,6 +12,7 @@ struct topk_softmax_trait std::string input_type; std::string weight_type; // currently always float int experts; + std::string activation; // "softmax" or "sigmoid" }; struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs From 4ee6677158ea47fc11e461a4bf7da38003048284 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 20 Oct 2025 14:14:34 +0000 Subject: [PATCH 2/5] fix formatting --- .../09_topk_softmax/topk_softmax_api.cpp | 37 ++++++++++--------- .../topk_softmax_warp_per_row_pipeline.hpp | 1 - .../topk_softmax_warp_per_row_problem.hpp | 10 ++--- .../topk_softmax/test_topk_softmax.hpp | 3 +- .../topk_softmax/test_topk_softmax_api.cpp | 37 ++++++++++--------- 5 files changed, 46 insertions(+), 42 deletions(-) diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp index bcc04916e6b..770468d36b9 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp @@ -3,23 +3,26 @@ #include "topk_softmax_api.hpp" -#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \ - constexpr ck_tile::index_t ts_experts = experts_; \ - constexpr bool ts_use_softmax = use_softmax_; \ - using ts_problem = ck_tile:: \ - TopkSoftmaxWarpPerRowProblem; \ - using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline; \ - \ - using kernel = ck_tile::TopkSoftmaxKernel; \ - \ - auto kargs = kernel::MakeKargs(a); \ - \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(); \ - \ - float ave_time = \ - ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ - \ +#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \ + constexpr ck_tile::index_t ts_experts = experts_; \ + constexpr bool ts_use_softmax = use_softmax_; \ + using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem; \ + using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline; \ + \ + using kernel = ck_tile::TopkSoftmaxKernel; \ + \ + auto kargs = kernel::MakeKargs(a); \ + \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(); \ + \ + float ave_time = \ + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ + \ return ave_time; float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s) diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp index 19fdbf46507..677263229b1 100644 --- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp +++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp @@ -112,7 +112,6 @@ struct TopkSoftmaxWarpPerRowPipeline // sigmoid was already pre-computed above, so only do topk now topk(w, out_win, idx_win, k); } - // check exit if constexpr(Problem::LaunchType == 0) diff --git a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp index 3c5f6b328b4..1dc7e9335e0 100644 --- a/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp +++ b/include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp @@ -13,11 +13,11 @@ template 0, persistent #occupancy - index_t BlockSize_ = 256> + bool ActivationIsSoftmax_ = true, // false: sigmoid + index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK + index_t BytesPerIssue_ = sizeof(InputType_), + index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy + index_t BlockSize_ = 256> struct TopkSoftmaxWarpPerRowProblem { // TODO: this kernel only support warp per row diff --git a/test/ck_tile/topk_softmax/test_topk_softmax.hpp b/test/ck_tile/topk_softmax/test_topk_softmax.hpp index cc64aec91ce..73f07035348 100644 --- a/test/ck_tile/topk_softmax/test_topk_softmax.hpp +++ b/test/ck_tile/topk_softmax/test_topk_softmax.hpp @@ -293,8 +293,7 @@ int run_gemm_combinations(std::string const& data_type) {"-t=333", "-e=99", "-k=13", "-st_i=191", "-st_o=17"}, {"-t=20", "-e=5", "-k=2", "-activation=sigmoid"}, {"-t=220", "-e=9", "-k=3", "-activation=sigmoid"}, - {"-t=500", "-e=21", "-k=13", "-activation=sigmoid"} - }; + {"-t=500", "-e=21", "-k=13", "-activation=sigmoid"}}; bool result = true; std::string pr_i = "-pr_i=" + data_type; diff --git a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp index 303f7cb10d8..e06935354b0 100644 --- a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp +++ b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp @@ -3,23 +3,26 @@ #include "test_topk_softmax_api.hpp" -#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \ - constexpr ck_tile::index_t ts_experts = experts_; \ - constexpr bool ts_use_softmax = use_softmax_; \ - using ts_problem = ck_tile:: \ - TopkSoftmaxWarpPerRowProblem; \ - using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline; \ - \ - using kernel = ck_tile::TopkSoftmaxKernel; \ - \ - auto kargs = kernel::MakeKargs(a); \ - \ - const dim3 grids = kernel::GridSize(a); \ - const dim3 blocks = kernel::BlockSize(); \ - \ - float ave_time = \ - ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ - \ +#define TOPK_SOFTMAX_DISPATCH(experts_, use_softmax_) \ + constexpr ck_tile::index_t ts_experts = experts_; \ + constexpr bool ts_use_softmax = use_softmax_; \ + using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem; \ + using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline; \ + \ + using kernel = ck_tile::TopkSoftmaxKernel; \ + \ + auto kargs = kernel::MakeKargs(a); \ + \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(); \ + \ + float ave_time = \ + ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \ + \ return ave_time; float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s) From 18bd9e9c9b2170b5ef53f3aa53cdee0262fe0a20 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Mon, 20 Oct 2025 14:22:10 +0000 Subject: [PATCH 3/5] add to changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9de78f30438..b364364c52b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added tensor-wise quantization for CK_TILE GEMM. * Added support for batched contraction kernel. * Added pooling kernel in CK_TILE +* Added top-k sigmoid kernel in CK_TILE ### Optimized From 906d9062d148c4af3915e106fce21baf79ddf000 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Tue, 28 Oct 2025 11:50:18 +0200 Subject: [PATCH 4/5] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- test/ck_tile/topk_softmax/test_topk_softmax_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp index e06935354b0..7e9ed6301ab 100644 --- a/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp +++ b/test/ck_tile/topk_softmax/test_topk_softmax_api.cpp @@ -96,7 +96,7 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c } #endif } - if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid") + else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid") { using ts_input_type = ck_tile::fp16_t; using ts_weight_type = float; From c795d00333408d34d9861e2f948be77587660308 Mon Sep 17 00:00:00 2001 From: Sami Remes Date: Tue, 28 Oct 2025 11:51:23 +0200 Subject: [PATCH 5/5] Use else if Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- example/ck_tile/09_topk_softmax/topk_softmax_api.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp index 770468d36b9..bdcfb47cc77 100644 --- a/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp +++ b/example/ck_tile/09_topk_softmax/topk_softmax_api.cpp @@ -96,7 +96,7 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c } #endif } - if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid") + else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid") { using ts_input_type = ck_tile::fp16_t; using ts_weight_type = float;