diff --git a/sgl-kernel/csrc/cpu/topk.cpp b/sgl-kernel/csrc/cpu/topk.cpp index 0471661e58a7..100e87a7c9ce 100644 --- a/sgl-kernel/csrc/cpu/topk.cpp +++ b/sgl-kernel/csrc/cpu/topk.cpp @@ -227,11 +227,9 @@ void topk_softmax_kernel_impl( queue[e] = {scores[e], e}; } - std::partial_sort( - queue.begin(), - queue.begin() + num_experts_per_group, - queue.end(), - [](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; }); + std::partial_sort(queue.begin(), queue.begin() + topk, queue.end(), [](const elem_t& x, const elem_t& y) -> bool { + return x.first > y.first; + }); for (int64_t j = 0; j < topk; ++j) { topk_weights[i * topk + j] = queue[j].first;