Skip to content

Commit 81b16a2

Browse files
authored
[Kernel] Better inf handling for grouped topk cu (#24886)
Signed-off-by: lumina37 <[email protected]>
1 parent e111d5b commit 81b16a2

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed

csrc/moe/grouped_topk_kernels.cu

Lines changed: 24 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -21,14 +21,14 @@
2121
#include <torch/all.h>
2222
#include <cuda_fp16.h>
2323
#include <cuda_bf16.h>
24+
#include <cuda/std/limits>
2425
#include <cooperative_groups.h>
2526
#include <cooperative_groups/reduce.h>
2627
namespace cg = cooperative_groups;
2728

2829
namespace vllm {
2930
namespace moe {
3031

31-
constexpr float kNegInfinity = INFINITY * -1;
3232
constexpr unsigned FULL_WARP_MASK = 0xffffffff;
3333
constexpr int32_t WARP_SIZE = 32;
3434
constexpr int32_t BLOCK_SIZE = 512;
@@ -411,14 +411,21 @@ __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
411411
return __bfloat162float(val);
412412
}
413413

414+
// Returns negative infinity expressed in type T.
//
// The value is produced in fp32 and then converted with cuda_cast, because
// cuda::std::numeric_limits<T>::infinity() returns `0` when T is bf16 or fp16
// and therefore cannot be used directly for those types.
template <typename T>
__device__ inline T neg_inf() {
  const float pos_inf_f32 = cuda::std::numeric_limits<float>::infinity();
  return cuda_cast<T, float>(-pos_inf_f32);
}
420+
414421
template <typename T>
415422
__device__ void topk_with_k2(T* output, T const* input,
416423
cg::thread_block_tile<32> const& tile,
417424
int32_t const lane_id,
418425
int const num_experts_per_group) {
419426
// Get the top2 per thread
420-
T largest = -INFINITY;
421-
T second_largest = -INFINITY;
427+
T largest = neg_inf<T>();
428+
T second_largest = neg_inf<T>();
422429

423430
if (num_experts_per_group > WARP_SIZE) {
424431
for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
@@ -513,8 +520,8 @@ __global__ void group_idx_and_topk_idx_kernel(
513520
warp_id * topk;
514521
s_topk_idx += warp_id * topk;
515522

516-
T value = kNegInfinity;
517-
T topk_group_value = kNegInfinity;
523+
T value = neg_inf<T>();
524+
T topk_group_value = neg_inf<T>();
518525
int32_t num_equalto_topkth_group;
519526

520527
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -525,11 +532,8 @@ __global__ void group_idx_and_topk_idx_kernel(
525532
if (case_id < num_tokens) {
526533
// calculate group_idx
527534
int32_t target_num_min = WARP_SIZE - n_group + topk_group;
528-
if (lane_id < n_group &&
529-
(isfinite(cuda_cast<float, T>(
530-
group_scores[lane_id])))) // The check is necessary to avoid
531-
// abnormal input
532-
{
535+
// The check is necessary to avoid abnormal input
536+
if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
533537
value = group_scores[lane_id];
534538
}
535539

@@ -540,23 +544,22 @@ __global__ void group_idx_and_topk_idx_kernel(
540544
__syncwarp(); // Ensure all threads have valid data before reduction
541545
topk_group_value = cg::reduce(tile, value, cg::greater<T>());
542546
if (value == topk_group_value) {
543-
value = kNegInfinity;
547+
value = neg_inf<T>();
544548
}
545549
pre_count_equal_to_top_value = count_equal_to_top_value;
546-
count_equal_to_top_value = __popc(__ballot_sync(
547-
FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
550+
count_equal_to_top_value =
551+
__popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
548552
}
549553
num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
550554
}
551555
__syncthreads();
552556

553557
warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
554558
/* is_stable */ true>
555-
queue((int32_t)topk, -INFINITY);
559+
queue((int32_t)topk, neg_inf<T>());
556560

557561
int count_equalto_topkth_group = 0;
558-
bool if_proceed_next_topk =
559-
(topk_group_value != cuda_cast<T, float>(kNegInfinity));
562+
bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
560563
if (case_id < num_tokens && if_proceed_next_topk) {
561564
for (int i_group = 0; i_group < n_group; i_group++) {
562565
if ((group_scores[i_group] > topk_group_value) ||
@@ -566,10 +569,10 @@ __global__ void group_idx_and_topk_idx_kernel(
566569
for (int32_t i = lane_id; i < align_num_experts_per_group;
567570
i += WARP_SIZE) {
568571
T candidates =
569-
(i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
570-
scores_with_bias[offset + i]))
572+
(i < num_experts_per_group) &&
573+
cuda::std::isfinite(scores_with_bias[offset + i])
571574
? scores_with_bias[offset + i]
572-
: cuda_cast<T, float>(kNegInfinity);
575+
: neg_inf<T>();
573576
queue.add(candidates, offset + i);
574577
}
575578
if (group_scores[i_group] == topk_group_value) {
@@ -598,7 +601,8 @@ __global__ void group_idx_and_topk_idx_kernel(
598601
if (i < topk) {
599602
s_topk_value[i] = value;
600603
}
601-
topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
604+
topk_sum +=
605+
cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
602606
}
603607
}
604608

0 commit comments

Comments
 (0)