 #include <torch/all.h>
 #include <cuda_fp16.h>
 #include <cuda_bf16.h>
+#include <cuda/std/limits>
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 namespace cg = cooperative_groups;

 namespace vllm {
 namespace moe {

-constexpr float kNegInfinity = INFINITY * -1;
 constexpr unsigned FULL_WARP_MASK = 0xffffffff;
 constexpr int32_t WARP_SIZE = 32;
 constexpr int32_t BLOCK_SIZE = 512;
@@ -411,14 +411,21 @@ __device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
   return __bfloat162float(val);
 }

+template <typename T>
+__device__ inline T neg_inf() {
+  // cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
+  // so we need to cast from fp32
+  return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
+}
+
 template <typename T>
 __device__ void topk_with_k2(T* output, T const* input,
                              cg::thread_block_tile<32> const& tile,
                              int32_t const lane_id,
                              int const num_experts_per_group) {
   // Get the top2 per thread
-  T largest = -INFINITY;
-  T second_largest = -INFINITY;
+  T largest = neg_inf<T>();
+  T second_largest = neg_inf<T>();

   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
@@ -513,8 +520,8 @@ __global__ void group_idx_and_topk_idx_kernel(
                  warp_id * topk;
   s_topk_idx += warp_id * topk;

-  T value = kNegInfinity;
-  T topk_group_value = kNegInfinity;
+  T value = neg_inf<T>();
+  T topk_group_value = neg_inf<T>();
   int32_t num_equalto_topkth_group;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -525,11 +532,8 @@ __global__ void group_idx_and_topk_idx_kernel(
   if (case_id < num_tokens) {
     // calculate group_idx
     int32_t target_num_min = WARP_SIZE - n_group + topk_group;
-    if (lane_id < n_group &&
-        (isfinite(cuda_cast<float, T>(
-            group_scores[lane_id]))))  // The check is necessary to avoid
-                                       // abnormal input
-    {
+    // The check is necessary to avoid abnormal input
+    if (lane_id < n_group && cuda::std::isfinite(group_scores[lane_id])) {
       value = group_scores[lane_id];
     }

@@ -540,23 +544,22 @@ __global__ void group_idx_and_topk_idx_kernel(
       __syncwarp();  // Ensure all threads have valid data before reduction
       topk_group_value = cg::reduce(tile, value, cg::greater<T>());
       if (value == topk_group_value) {
-        value = kNegInfinity;
+        value = neg_inf<T>();
       }
       pre_count_equal_to_top_value = count_equal_to_top_value;
-      count_equal_to_top_value = __popc(__ballot_sync(
-          FULL_WARP_MASK, (value == cuda_cast<T, float>(kNegInfinity))));
+      count_equal_to_top_value =
+          __popc(__ballot_sync(FULL_WARP_MASK, (value == neg_inf<T>())));
     }
     num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
   }
   __syncthreads();

   warp_topk::WarpSelect</*capability*/ WARP_SIZE, /*greater*/ true, T, int32_t,
                         /* is_stable */ true>
-      queue((int32_t)topk, -INFINITY);
+      queue((int32_t)topk, neg_inf<T>());

   int count_equalto_topkth_group = 0;
-  bool if_proceed_next_topk =
-      (topk_group_value != cuda_cast<T, float>(kNegInfinity));
+  bool if_proceed_next_topk = topk_group_value != neg_inf<T>();
   if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i_group = 0; i_group < n_group; i_group++) {
       if ((group_scores[i_group] > topk_group_value) ||
@@ -566,10 +569,10 @@ __global__ void group_idx_and_topk_idx_kernel(
         for (int32_t i = lane_id; i < align_num_experts_per_group;
              i += WARP_SIZE) {
           T candidates =
-              (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
-                                                 scores_with_bias[offset + i]))
+              (i < num_experts_per_group) &&
+                      cuda::std::isfinite(scores_with_bias[offset + i])
                   ? scores_with_bias[offset + i]
-                  : cuda_cast<T, float>(kNegInfinity);
+                  : neg_inf<T>();
           queue.add(candidates, offset + i);
         }
         if (group_scores[i_group] == topk_group_value) {
@@ -598,7 +601,8 @@ __global__ void group_idx_and_topk_idx_kernel(
       if (i < topk) {
         s_topk_value[i] = value;
       }
-      topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
+      topk_sum +=
+          cg::reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
     }
   }

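Aside (not part of the commit): a minimal standalone sketch of the issue the new neg_inf<T>() helper works around. As the added comment notes, cuda::std::numeric_limits<T>::infinity() can come out as 0 for __half/__nv_bfloat16 (when no specialization is available, the primary template value-initializes T), so a true negative infinity has to be produced by converting from fp32. The probe kernel and its launch below are illustrative assumptions, not code from the repository.

#include <cstdio>

#include <cuda_bf16.h>
#include <cuda/std/limits>

// Hypothetical standalone probe (illustrative only, not kernel code).
__global__ void neg_inf_probe() {
  // If the toolkit ships no numeric_limits specialization for __nv_bfloat16,
  // the primary template value-initializes, and this prints 0 instead of inf.
  __nv_bfloat16 from_limits =
      cuda::std::numeric_limits<__nv_bfloat16>::infinity();
  // Converting an fp32 infinity, as neg_inf<T>() does via cuda_cast, yields a
  // genuine bf16 (negative) infinity.
  __nv_bfloat16 from_cast =
      __float2bfloat16(-cuda::std::numeric_limits<float>::infinity());
  printf("limits: %f  cast from fp32: %f\n", __bfloat162float(from_limits),
         __bfloat162float(from_cast));
}

int main() {
  neg_inf_probe<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}

On toolkits whose libcu++ does specialize numeric_limits for the 16-bit float types, the first value may print as inf instead of 0; converting from fp32 gives the same result regardless of toolkit version, which is what the helper relies on.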