From 70dcb88b038fc446c8a7046572e1b2767df1a694 Mon Sep 17 00:00:00 2001 From: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> Date: Sat, 28 Jun 2025 15:24:24 +0000 Subject: [PATCH] Refactor the topk parallelization Signed-off-by: Christina Zhang <83400082+ChristinaZ@users.noreply.github.com> --- .../blockScaleMoe/DevKernel.h | 18 +- .../blockScaleMoe/RoutingKernel.cu | 772 +++--------------- .../blockScaleMoe/RoutingKernel.h | 45 +- .../blockScaleMoe/RoutingKernelTopK.cuh | 218 +++++ .../trtllmGenKernels/blockScaleMoe/runner.cu | 12 +- .../kernels/routing/routingDeepSeekTest.cpp | 4 +- .../kernels/routing/routingLlama4Test.cpp | 6 +- .../routing/routingRenormalizeTest.cpp | 4 +- .../unit_tests/kernels/routing/routingTest.h | 6 +- 9 files changed, 355 insertions(+), 730 deletions(-) create mode 100644 cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h index 1c315800cda..7bf2e6b7463 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h @@ -171,8 +171,8 @@ namespace moe::dev } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ { \ - LAUNCH_PDL_QWEN3(data, coopLaunch, LAUCNCH_ESC(void, cutlass::bfloat16_t), kernel, numBlocks, numThreads, \ - smemSize, stream); \ + LAUNCH_PDL_QWEN3( \ + data, coopLaunch, LAUCNCH_ESC(void, __nv_bfloat16), kernel, numBlocks, numThreads, smemSize, stream); \ } \ else \ { \ @@ -186,8 +186,8 @@ namespace moe::dev } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ { \ - LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(void, cutlass::bfloat16_t), kernel, numBlocks, numThreads, smemSize, \ - stream); \ + LAUNCH_PDL( \ + data, coopLaunch, LAUCNCH_ESC(void, __nv_bfloat16), kernel, numBlocks, numThreads, smemSize, stream); \ } \ else \ { \ @@ -201,7 +201,7 @@ namespace moe::dev } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ { \ - LAUNCH_PDL(data, coopLaunch, cutlass::bfloat16_t, kernel, numBlocks, numThreads, smemSize, stream); \ + LAUNCH_PDL(data, coopLaunch, __nv_bfloat16, kernel, numBlocks, numThreads, smemSize, stream); \ } \ else \ { \ @@ -219,13 +219,13 @@ namespace moe::dev } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && data.mNumExpertGroups > 1) \ { \ - LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(cutlass::bfloat16_t, true), kernel, numBlocks, numThreads, smemSize, \ - stream); \ + LAUNCH_PDL( \ + data, coopLaunch, LAUCNCH_ESC(__nv_bfloat16, true), kernel, numBlocks, numThreads, smemSize, stream); \ } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ { \ - LAUNCH_PDL(data, coopLaunch, LAUCNCH_ESC(cutlass::bfloat16_t, false), kernel, numBlocks, numThreads, smemSize, \ - stream); \ + LAUNCH_PDL( \ + data, coopLaunch, LAUCNCH_ESC(__nv_bfloat16, false), kernel, numBlocks, numThreads, smemSize, stream); \ } \ else \ { \ diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu index a4e36624397..9b7fc7d17c0 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu @@ -16,6 +16,7 @@ #include "DevKernel.h" #include "RoutingKernel.h" +#include "RoutingKernelTopK.cuh" #include #include @@ -40,219 +41,149 @@ static constexpr bool TLLM_GEN_HAS_FAST_REDUX = tensorrt_llm::kernels::arch::is_ namespace routing { -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace tg = batchedGemm::trtllm::gen; namespace cg = cooperative_groups; //////////////////////////////////////////////////////////////////////////////////////////////////// -static constexpr int NumThreads = 256; -static constexpr int NumBlocksPerCluster = 8; static constexpr int WarpSize = 32; -static constexpr int NumWarps = NumThreads / WarpSize; -static constexpr int NumTopGroupScores = 2; -static constexpr int MaxNumTopExperts = 8; -static constexpr int MaxNumTopGroups = 4; - +static constexpr int NumBlocksPerCluster = 8; // Performance tuning knob. static constexpr int NumEltsPerOffsetTilePerThread = 8; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct TopKRedType -{ - using TypeExpW = TypeExpW_; - static_assert(std::is_same_v || std::is_same_v, - "Top K reduction only implemented for float and Bf16"); - using TypeCmp = std::conditional_t= 4, double, float>; - static constexpr int64_t Mask64 = 0x000000000000FFFF; - static constexpr int32_t Mask32 = 0x0000FFFF; - - TypeCmp compVal; - - static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0) - { - auto cmpVal = TypeCmp{val}; - TypeCmp cmpValWithIdx; - if constexpr (sizeof(TypeExpW) >= 4) - { - auto cmpValIdx64 = reinterpret_cast(cmpVal) | (Mask64& int64_t{idx}); - cmpValWithIdx = reinterpret_cast(cmpValIdx64); - } - else - { - auto cmpValIdx32 = reinterpret_cast(cmpVal) | (Mask32 & idx); - cmpValWithIdx = reinterpret_cast(cmpValIdx32); - } - return cmpValWithIdx; - } - - static __host__ __device__ inline void unpack(TypeExpW& val, int32_t& idx, TypeCmp cmp) - { - if constexpr (sizeof(TypeExpW) >= 4) - { - idx = static_cast(reinterpret_cast(cmp) & Mask64); - auto val64 = reinterpret_cast(cmp) & ~Mask64; - val = static_cast(reinterpret_cast(val64)); - } - else - { - idx = reinterpret_cast(cmp) & Mask32; - auto val32 = reinterpret_cast(cmp) >> 16; - val = TypeExpW::bitcast(reinterpret_cast(val32)); - } - } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL)) +#define TLLM_GEN_ENABLE_FAST_REDUX +#endif - __host__ __device__ TopKRedType() = default; +//////////////////////////////////////////////////////////////////////////////////////////////////// - __host__ __device__ TopKRedType(TypeExpW val, int32_t idx) - : compVal(makeCmpVal(val, idx)) - { - } +static __device__ inline float sigmoid_accurate(float x) +{ + return 0.5f * tanhf(0.5f * x) + 0.5f; +} - __host__ __device__ operator TypeCmp() const noexcept - { - return compVal; - } +//////////////////////////////////////////////////////////////////////////////////////////////////// - __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) - { - if constexpr (!TLLM_GEN_HAS_FAST_REDUX || sizeof(TypeExpW) >= 4) - { - return cg::reduce(warp, compVal, cg::greater{}); - } - else - { - float result; - asm("redux.sync.max.f32 %0, %1, 0xffffffff;\n" : "=f"(result) : "f"(compVal)); - return result; - } - } -}; +template +__host__ __device__ constexpr T mulLog2(T a, T bLog2) +{ + return a << bLog2; +} //////////////////////////////////////////////////////////////////////////////////////////////////// -static __device__ inline float sigmoid_accurate(float x) +template +__host__ __device__ constexpr T divUpLog2(T a, T bLog2) { - return 0.5f * tanhf(0.5f * x) + 0.5f; + return ((a + (1 << bLog2) - 1) >> bLog2); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct TopKIdx +template +__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2) { - // by default, empty -}; + return mulLog2(divUpLog2(a, bLog2), bLog2); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct TopKIdx +__host__ __device__ constexpr int32_t getBits(int32_t value, int idx) { - static constexpr int K = K_; - int32_t val[K]; -}; + int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000; + return (value & mask) >> (idx * 8); +} //////////////////////////////////////////////////////////////////////////////////////////////////// -template -__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type value, int32_t idx, Type minValue) +template +__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx) { - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); - using RedType = TopKRedType; - RedType topK{value, idx}; - typename RedType::TypeCmp packedMax{}; -#pragma unroll - for (int kk = 0; kk < K; ++kk) + if constexpr (!IsZero) { - topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK; - // get the next largest value - packedMax = topK.reduce(warp); - RedType::unpack(out[kk], outIdx[kk], packedMax); + int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF; + value &= mask; } -}; + value |= (newBits << (idx * 8)); +} //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__device__ void calcSoftmax(cg::thread_block_tile const& warp, TypeExpW (&scores)[VecSize]) +{ + TypeExpW maxScore = TypeExpW{-INFINITY}; + TypeExpW sumScore = TypeExpW{0.f}; -#define TOPK_SWAP(I, J) \ - { \ - auto pairMin = min(topK[I].compVal, topK[J].compVal); \ - auto pairMax = max(topK[I].compVal, topK[J].compVal); \ - topK[I].compVal = pairMax; \ - topK[J].compVal = pairMin; \ + // Get the max score for each token + for (int i = 0; i < VecSize; ++i) + { + maxScore = scores[i] >= maxScore ? scores[i] : maxScore; } + maxScore = cg::reduce(warp, maxScore, cg::greater()); -template -__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], Type minValue) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); - static_assert(N > 0, "Top K must have N > 1"); - static_assert(N <= K, "Top K must have N < K"); - using RedType = TopKRedType; - RedType topK[N]; + // Get the summation of scores for each token #pragma unroll - for (int nn = 0; nn < N; ++nn) - topK[nn] = RedType{value[nn], idx[nn]}; - if constexpr (!IsSorted) + for (int i = 0; i < VecSize; ++i) { - static_assert(N <= 4, "Unsorted topK expects N <= 4"); - TOPK_SWAP(0, 2); - TOPK_SWAP(1, 3); - - TOPK_SWAP(0, 1); - TOPK_SWAP(2, 3); - - TOPK_SWAP(1, 2); + scores[i] = static_cast(exp(scores[i] - maxScore)); + sumScore += scores[i]; } - typename RedType::TypeCmp packedMax{}; + sumScore = cg::reduce(warp, sumScore, cg::plus()); + + // Normalize the scores #pragma unroll - for (int kk = 0; kk < K; ++kk) + for (int i = 0; i < VecSize; ++i) { - bool update = kk > 0 && packedMax == topK[0].compVal; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - { - topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn]; - } - // get the next largest value - packedMax = topK[0].reduce(warp); - RedType::unpack(out[kk], outIdx[kk], packedMax); + scores[i] = static_cast(scores[i] / sumScore); } -}; - -#undef TOPK_SWAP +} //////////////////////////////////////////////////////////////////////////////////////////////////// -template -__host__ __device__ constexpr T mulLog2(T a, T bLog2) +template +__device__ TypeExpW calcSoftmax( + cg::thread_block_tile const& warp, TypeExpW score, int32_t laneIdx, int32_t NumTopExperts) { - return a << bLog2; -} + TypeExpW maxScore = TypeExpW{-INFINITY}; + if (laneIdx < NumTopExperts) + { + maxScore = score >= maxScore ? score : maxScore; + } + maxScore = cg::reduce(warp, maxScore, cg::greater()); -//////////////////////////////////////////////////////////////////////////////////////////////////// + float sumScore = float{0.f}; + float newScore; + // Get the summation of scores for each token + if (laneIdx < NumTopExperts) + { + newScore = static_cast(score) - static_cast(maxScore); + newScore = static_cast(exp(newScore)); + sumScore += newScore; + } + sumScore = cg::reduce(warp, sumScore, cg::plus()); -template -__host__ __device__ constexpr T divUpLog2(T a, T bLog2) -{ - return ((a + (1 << bLog2) - 1) >> bLog2); + if (laneIdx < NumTopExperts) + { + score = static_cast(newScore / sumScore); + } + + return score; } //////////////////////////////////////////////////////////////////////////////////////////////////// -template -__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2) +namespace routingDeepSeek { - return mulLog2(divUpLog2(a, bLog2), bLog2); -} //////////////////////////////////////////////////////////////////////////////////////////////////// +static constexpr int NumThreads = 256; +static constexpr int NumWarps = NumThreads / WarpSize; +static constexpr int NumTopGroupScores = 2; +static constexpr int MaxNumTopExperts = 8; +static constexpr int MaxNumTopGroups = 4; + template __global__ void routingMainKernel(KernelParams params) { @@ -340,7 +271,7 @@ __global__ void routingMainKernel(KernelParams params) if constexpr (KernelParams::UseGroups) { - reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, + topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, /* minValue */ invalidScoreFloat); // get the final group score and write it to shared @@ -362,7 +293,7 @@ __global__ void routingMainKernel(KernelParams params) { float groupScore = laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat; - reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, /* minValue */ invalidScoreFloat); // final expert selection: get relevant indexes and scores from shared @@ -396,7 +327,7 @@ __global__ void routingMainKernel(KernelParams params) } } - reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, /* minValue */ invalidScoreFloat); // determine our lane's expert index and write to output @@ -1309,7 +1240,7 @@ void run(Data const& data, void* stream) //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace routing +} // namespace routingDeepSeek //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1318,15 +1249,8 @@ namespace routingLlama4 //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace tg = batchedGemm::trtllm::gen; -namespace cg = cooperative_groups; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - static constexpr int NumThreads = 1024; static constexpr int NumThreadsHist = 256; -static constexpr int NumBlocksPerCluster = 8; -static constexpr int WarpSize = 32; static constexpr int NumWarps = NumThreads / WarpSize; static constexpr int NumWarpsHist = NumThreadsHist / WarpSize; static constexpr int NumTopExperts = 1; @@ -1339,222 +1263,6 @@ static constexpr int WarpKernelSmemStride = 33; // operations are more efficient end-to-end. static constexpr int WarpKernelMaxNumTokens = 4; -// Performance tuning knob. -static constexpr int NumEltsPerOffsetTilePerThread = 8; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TopKRedType -{ - using TypeExpW = TypeExpW_; - static_assert(std::is_same_v || std::is_same_v, - "Top K reduction only implemented for float and Bf16"); - using TypeCmp = std::conditional_t= 4, double, float>; - static constexpr int64_t Mask64 = 0x000000000000FFFF; - static constexpr int32_t Mask32 = 0x0000FFFF; - - TypeCmp compVal; - - static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0) - { - auto cmpVal = TypeCmp{val}; - TypeCmp cmpValWithIdx; - if constexpr (sizeof(TypeExpW) >= 4) - { - auto cmpValIdx64 = reinterpret_cast(cmpVal) | (Mask64& int64_t{idx}); - cmpValWithIdx = reinterpret_cast(cmpValIdx64); - } - else - { - auto cmpValIdx32 = reinterpret_cast(cmpVal) | (Mask32 & idx); - cmpValWithIdx = reinterpret_cast(cmpValIdx32); - } - return cmpValWithIdx; - } - - static __host__ __device__ inline void unpack(TypeExpW& val, int32_t& idx, TypeCmp cmp) - { - if constexpr (sizeof(TypeExpW) >= 4) - { - idx = static_cast(reinterpret_cast(cmp) & Mask64); - auto val64 = reinterpret_cast(cmp) & ~Mask64; - val = static_cast(reinterpret_cast(val64)); - } - else - { - idx = reinterpret_cast(cmp) & Mask32; - auto val32 = reinterpret_cast(cmp) >> 16; - val = TypeExpW::bitcast(reinterpret_cast(val32)); - } - } - - __host__ __device__ TopKRedType() = default; - - __host__ __device__ TopKRedType(TypeExpW val, int32_t idx) - : compVal(makeCmpVal(val, idx)) - { - } - - __host__ __device__ operator TypeCmp() const noexcept - { - return compVal; - } - - __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) - { - if constexpr (!TLLM_GEN_HAS_FAST_REDUX || sizeof(TypeExpW) >= 4) - { - return cg::reduce(warp, compVal, cg::greater{}); - } - else - { - float result; - asm("redux.sync.max.f32 %0, %1, 0xffffffff;\n" : "=f"(result) : "f"(compVal)); - return result; - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static __device__ inline float sigmoid_accurate(float x) -{ - return 0.5f * tanhf(0.5f * x) + 0.5f; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TopKIdx -{ - // by default, empty -}; - -template -struct TopKIdx -{ - static constexpr int K = K_; - int32_t val[K]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type value, int32_t idx, Type minValue) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); - using RedType = TopKRedType; - RedType topK{value, idx}; - typename RedType::TypeCmp packedMax{}; -#pragma unroll - for (int kk = 0; kk < K; ++kk) - { - topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK; - // get the next largest value - packedMax = topK.reduce(warp); - RedType::unpack(out[kk], outIdx[kk], packedMax); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define TOPK_SWAP(I, J) \ - { \ - auto pairMin = min(topK[I].compVal, topK[J].compVal); \ - auto pairMax = max(topK[I].compVal, topK[J].compVal); \ - topK[I].compVal = pairMax; \ - topK[J].compVal = pairMin; \ - } - -template -__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], Type minValue) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); - static_assert(N > 0, "Top K must have N > 1"); - static_assert(N <= K, "Top K must have N < K"); - using RedType = TopKRedType; - RedType topK[N]; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - topK[nn] = RedType{value[nn], idx[nn]}; - if constexpr (!IsSorted) - { - TOPK_SWAP(0, 2); - TOPK_SWAP(1, 3); - - TOPK_SWAP(0, 1); - TOPK_SWAP(2, 3); - - TOPK_SWAP(1, 2); - } - typename RedType::TypeCmp packedMax{}; -#pragma unroll - for (int kk = 0; kk < K; ++kk) - { - bool update = kk > 0 && packedMax == topK[0].compVal; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - { - topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn]; - } - // get the next largest value - packedMax = topK[0].reduce(warp); - RedType::unpack(out[kk], outIdx[kk], packedMax); - } -}; - -#undef TOPK_SWAP - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__host__ __device__ constexpr T mulLog2(T a, T bLog2) -{ - return a << bLog2; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__host__ __device__ constexpr T divUpLog2(T a, T bLog2) -{ - return ((a + (1 << bLog2) - 1) >> bLog2); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2) -{ - return mulLog2(divUpLog2(a, bLog2), bLog2); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -__host__ __device__ constexpr int32_t getBits(int32_t value, int idx) -{ - int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000; - return (value & mask) >> (idx * 8); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx) -{ - if constexpr (!IsZero) - { - int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF; - value &= mask; - } - value |= (newBits << (idx * 8)); -} - //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -1625,7 +1333,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam int32_t warpMaxExpertIdx[NumTopExperts]; TypeExpW warpMaxScore[NumTopExperts]; // warp-wide reduction to get the best score for all experts - reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore); + topk::reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore); if (cute::elect_one_sync()) { // one thread updates the count linking token to chosen expert @@ -1863,7 +1571,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu } int32_t warpMaxExpertIdx[NumTopExperts]; TypeExpW warpMaxScore[NumTopExperts]; - reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore); + topk::reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore); if (cute::elect_one_sync()) { TypePacked packedScore{warpMaxScore[0], static_cast(warpMaxExpertIdx[0])}; @@ -2102,7 +1810,7 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK } int32_t warpMaxExpertIdx[NumTopExperts]; TypeExpW warpMaxScore[NumTopExperts]; - reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore); + topk::reduceTopK(warp, warpMaxScore, warpMaxExpertIdx, maxScore, maxExpertIdx, minScore); if (cute::elect_one_sync()) { TypePacked packedScore{warpMaxScore[0], static_cast(warpMaxExpertIdx[0])}; @@ -2565,17 +2273,10 @@ void run(Data const& data, void* stream) namespace routingRenormalize { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cg = cooperative_groups; - //////////////////////////////////////////////////////////////////////////////////////////////////// static constexpr int NumThreads = 1024; static constexpr int NumThreadsHist = 256; -static constexpr int NumBlocksPerCluster = 8; -static constexpr int WarpSize = 32; static constexpr int NumWarps = NumThreads / WarpSize; static constexpr int NumWarpsHist = NumThreadsHist / WarpSize; static constexpr int NumTopExperts = 8; @@ -2583,284 +2284,6 @@ static constexpr int MaxNumExperts = 128; static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; -// Performance tuning knob. -static constexpr int NumEltsPerOffsetTilePerThread = 8; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TopKRedType -{ - using TypeExpW = TypeExpW_; - static_assert(std::is_same_v || std::is_same_v, - "Top K reduction only implemented for float and Bf16"); - using TypeCmp = std::conditional_t= 4, double, float>; - static constexpr int64_t Mask64 = 0x000000000000FFFF; - static constexpr int32_t Mask32 = 0x0000FFFF; - - TypeCmp compVal; - - static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0) - { - auto cmpVal = TypeCmp{val}; - TypeCmp cmpValWithIdx; - if constexpr (sizeof(TypeExpW) >= 4) - { - auto cmpValIdx64 = reinterpret_cast(cmpVal) | (Mask64& int64_t{idx}); - cmpValWithIdx = reinterpret_cast(cmpValIdx64); - } - else - { - auto cmpValIdx32 = reinterpret_cast(cmpVal) | (Mask32 & idx); - cmpValWithIdx = reinterpret_cast(cmpValIdx32); - } - return cmpValWithIdx; - } - - static __host__ __device__ inline void unpack(TypeExpW& val, int32_t& idx, TypeCmp cmp) - { - if constexpr (sizeof(TypeExpW) >= 4) - { - idx = static_cast(reinterpret_cast(cmp) & Mask64); - auto val64 = reinterpret_cast(cmp) & ~Mask64; - val = static_cast(reinterpret_cast(val64)); - } - else - { - idx = reinterpret_cast(cmp) & Mask32; - auto val32 = reinterpret_cast(cmp) >> 16; - val = TypeExpW::bitcast(reinterpret_cast(val32)); - } - } - - __host__ __device__ TopKRedType() = default; - - __host__ __device__ TopKRedType(TypeExpW val, int32_t idx) - : compVal(makeCmpVal(val, idx)) - { - } - - __host__ __device__ operator TypeCmp() const noexcept - { - return compVal; - } - - __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) - { - if constexpr (!TLLM_GEN_HAS_FAST_REDUX || sizeof(TypeExpW) >= 4) - { - return cg::reduce(warp, compVal, cg::greater{}); - } - else - { - float result; - asm("redux.sync.max.f32 %0, %1, 0xffffffff;\n" : "=f"(result) : "f"(compVal)); - return result; - } - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct TopKIdx -{ - // by default, empty -}; - -template -struct TopKIdx -{ - static constexpr int K = K_; - int32_t val[K]; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type value, int32_t idx, Type minValue) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); - using RedType = TopKRedType; - RedType topK{value, idx}; - typename RedType::TypeCmp packedMax{}; -#pragma unroll - for (int kk = 0; kk < K; ++kk) - { - topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK; - // get the next largest value - packedMax = topK.reduce(warp); - RedType::unpack(out[kk], outIdx[kk], packedMax); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define TOPK_SWAP(I, J) \ - { \ - auto pairMin = min(topK[I].compVal, topK[J].compVal); \ - auto pairMax = max(topK[I].compVal, topK[J].compVal); \ - topK[I].compVal = pairMax; \ - topK[J].compVal = pairMin; \ - } - -template -__device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], Type minValue) -{ - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); - static_assert(N > 0, "Top K must have N > 1"); - // static_assert(N <= K, "Top K must have N < K"); - using RedType = TopKRedType; - RedType topK[N]; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - { - topK[nn] = RedType{value[nn], idx[nn]}; - } - - if constexpr (!IsSorted) - { - TOPK_SWAP(0, 2); - TOPK_SWAP(1, 3); - - TOPK_SWAP(0, 1); - TOPK_SWAP(2, 3); - - TOPK_SWAP(1, 2); - } - typename RedType::TypeCmp packedMax{}; -#pragma unroll - for (int kk = 0; kk < K; ++kk) - { - bool update = kk > 0 && packedMax == topK[0].compVal; -#pragma unroll - for (int nn = 0; nn < N; ++nn) - { - topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn]; - } - // get the next largest value - packedMax = topK[0].reduce(warp); - RedType::unpack(out[kk], outIdx[kk], packedMax); - } -}; - -#undef TOPK_SWAP - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__host__ __device__ constexpr T mulLog2(T a, T bLog2) -{ - return a << bLog2; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__host__ __device__ constexpr T divUpLog2(T a, T bLog2) -{ - return ((a + (1 << bLog2) - 1) >> bLog2); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__host__ __device__ constexpr T divUpMulLog2(T a, T bLog2) -{ - return mulLog2(divUpLog2(a, bLog2), bLog2); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -__host__ __device__ constexpr int32_t getBits(int32_t value, int idx) -{ - int mask = idx == 0 ? 0x000000FF : idx == 1 ? 0x0000FF00 : idx == 2 ? 0x00FF0000 : 0xFF000000; - return (value & mask) >> (idx * 8); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__host__ __device__ constexpr void setBits(int32_t& value, int32_t newBits, int idx) -{ - if constexpr (!IsZero) - { - int mask = idx == 0 ? 0xFFFFFF00 : idx == 1 ? 0xFFFF00FF : idx == 2 ? 0xFF00FFFF : 0x00FFFFFF; - value &= mask; - } - value |= (newBits << (idx * 8)); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ void calcSoftmax(cg::thread_block_tile const& warp, TypeExpW (&scores)[VecSize]) -{ - TypeExpW maxScore = TypeExpW{-INFINITY}; - TypeExpW sumScore = TypeExpW{0.f}; - - // Get the max score for each token - for (int i = 0; i < VecSize; ++i) - { - maxScore = scores[i] >= maxScore ? scores[i] : maxScore; - } - maxScore = cg::reduce(warp, maxScore, cg::greater()); - - // Get the summation of scores for each token -#pragma unroll - for (int i = 0; i < VecSize; ++i) - { - scores[i] = static_cast(exp(scores[i] - maxScore)); - sumScore += scores[i]; - } - sumScore = cg::reduce(warp, sumScore, cg::plus()); - - // Normalize the scores -#pragma unroll - for (int i = 0; i < VecSize; ++i) - { - scores[i] = static_cast(scores[i] / sumScore); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__device__ TypeExpW calcSoftmax( - cg::thread_block_tile const& warp, TypeExpW score, int32_t laneIdx, int32_t NumTopExperts) -{ - TypeExpW maxScore = TypeExpW{-INFINITY}; - if (laneIdx < NumTopExperts) - { - maxScore = score >= maxScore ? score : maxScore; - } - maxScore = cg::reduce(warp, maxScore, cg::greater()); - - float sumScore = float{0.f}; - float newScore; - // Get the summation of scores for each token - if (laneIdx < NumTopExperts) - { - newScore = static_cast(score) - static_cast(maxScore); - newScore = static_cast(exp(newScore)); - sumScore += newScore; - } - sumScore = cg::reduce(warp, sumScore, cg::plus()); - - if (laneIdx < NumTopExperts) - { - score = static_cast(newScore / sumScore); - } - - return score; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) @@ -2954,7 +2377,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu } // Get the top-k scores and their corresponding expert indices - reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore); + topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore); // Normalize the scores if constexpr (DoSoftmaxBeforeTopK) @@ -3253,7 +2676,7 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesHistogramScoresK } // Get the top-k scores and their corresponding expert indices - reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, allScores, allExpertIdx, minScore); + topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, allScores, allExpertIdx, minScore); __syncwarp(); //@TODO: check the synchronization // Normalize the scores @@ -3715,4 +3138,5 @@ void run(Data const& data, void* stream) //////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace routing } // namespace moe::dev diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h index ecd7ce7654b..65b5bcd0ff9 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h @@ -26,14 +26,23 @@ namespace moe::dev { -//////////////////////////////////////////////////////////////////////////////////////////////////// - namespace routing { -//////////////////////////////////////////////////////////////////////////////////////////////////// namespace tg = batchedGemm::trtllm::gen; +template +struct PackedScoreIdx +{ + TypeExpW score; + int16_t idx; // @TODO: Might use int8_t as the number of experts is 128 +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace routingDeepSeek +{ + //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data @@ -177,7 +186,7 @@ struct KernelParams void run(Data const& data, void* stream); -} // namespace routing +} // namespace routingDeepSeek //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -186,19 +195,6 @@ namespace routingLlama4 //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace tg = batchedGemm::trtllm::gen; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct PackedScoreIdx -{ - TypeExpW score; - int16_t idx; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - struct Data { tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; @@ -310,19 +306,6 @@ namespace routingRenormalize //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace tg = batchedGemm::trtllm::gen; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct PackedScoreIdx -{ - TypeExpW score; - int16_t idx; // @TODO: Might use int8_t as the number of experts is 128 -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - struct Data { tg::Dtype mDtypeExpW{tg::Dtype::Fp32}; @@ -431,5 +414,5 @@ void run(Data const& data, void* stream); } // namespace routingRenormalize //////////////////////////////////////////////////////////////////////////////////////////////////// - +} // namespace routing } // namespace moe::dev diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh new file mode 100644 index 00000000000..1d9b7c40b29 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh @@ -0,0 +1,218 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace moe::dev::routing +{ + +namespace topk +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +namespace cg = cooperative_groups; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int WarpSize = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TopKRedType +{ + using TypeExpW = TypeExpW_; + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v, + "Top K reduction only implemented for float, float16 and bfloat16"); + + using TypeCmp = std::conditional_t; + using IdxT = std::conditional_t; + static constexpr int moveBits = (sizeof(TypeExpW) == 4) ? 32 : 16; + static constexpr int maxIdx = 65535; + + TypeCmp compVal; + + static __host__ __device__ inline TypeCmp makeCmpVal(TypeExpW val, int32_t idx = 0) + { + auto valueBits + = cub::Traits::TwiddleIn(reinterpret_cast::UnsignedBits&>(val)); + TypeCmp compactTmp = reinterpret_cast(valueBits); + compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx)); + // Use 65535 minus idx to give higher priority to elements with smaller indices. + return compactTmp; + } + + static __host__ __device__ inline void unpack(TypeExpW& value, int32_t& index, TypeCmp cmp) + { + // Since idx is always smaller than 65536 and positive, we can directly use it as the lower 16 + // bits + index = maxIdx - static_cast(cmp & 0xFFFF); + + auto compactTmp = cmp >> moveBits; + auto valueBits = cub::Traits::TwiddleOut( + reinterpret_cast::UnsignedBits&>(compactTmp)); + value = reinterpret_cast(valueBits); + } + + __host__ __device__ TopKRedType() = default; + + __host__ __device__ TopKRedType(TypeExpW val, int32_t idx) + : compVal(makeCmpVal(val, idx)) + { + } + + __host__ __device__ operator TypeCmp() const noexcept + { + return compVal; + } + + __device__ inline TypeCmp reduce(cg::thread_block_tile const& warp) + { +#if defined(TLLM_GEN_HAS_FAST_REDUX) + static constexpr bool UseCg = false; +#else + static constexpr bool UseCg = true; +#endif + if constexpr (UseCg || sizeof(TypeCmp) == 8) + { + return cg::reduce(warp, compVal, cg::greater{}); + } + else + { + TypeCmp result; + asm("redux.sync.max.u32 %0, %1, 0xffffffff;\n" : "=r"(result) : "r"(compVal)); + return result; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define TOPK_SWAP(I, J) \ + { \ + auto pairMin = min(topK[I].compVal, topK[J].compVal); \ + auto pairMax = max(topK[I].compVal, topK[J].compVal); \ + topK[I].compVal = pairMax; \ + topK[J].compVal = pairMin; \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sort; + +template +struct Sort<1, RedType> +{ + static __device__ void run(RedType* topK) {} +}; + +template +struct Sort<2, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<3, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 1); + TOPK_SWAP(1, 2); + TOPK_SWAP(0, 1); + } +}; + +template +struct Sort<4, RedType> +{ + static __device__ void run(RedType* topK) + { + TOPK_SWAP(0, 2); + TOPK_SWAP(1, 3); + TOPK_SWAP(0, 1); + TOPK_SWAP(2, 3); + TOPK_SWAP(1, 2); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], + int32_t (&outIdx)[K], Type value, int32_t idx, Type const minValue) +{ + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < WarpSize, "Top K must have K < WarpSize"); + using RedType = TopKRedType; + RedType topK{value, idx}; + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < K; ++kk) + { + topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK; + // get the next largest value + packedMax = topK.reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +template +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, Type (&out)[K], + int32_t (&outIdx)[K], Type (&value)[N], int32_t (&idx)[N], Type const minValue) +{ + static_assert(K > 0, "Top K must have K > 0"); + static_assert(K < WarpSize, "Top K must have K < WarpSize"); + static_assert(N > 0, "Top K must have N > 0"); + static_assert(N < 5, "Only support candidates number less than or equal to 128"); + using RedType = TopKRedType; + RedType topK[N]; +#pragma unroll + for (int nn = 0; nn < N; ++nn) + { + topK[nn] = RedType{value[nn], idx[nn]}; + } + + Sort::run(topK); + + typename RedType::TypeCmp packedMax{}; +#pragma unroll + for (int kk = 0; kk < K; ++kk) + { + bool update = kk > 0 && packedMax == topK[0].compVal; +#pragma unroll + for (int nn = 0; nn < N; ++nn) + { + topK[nn] = update && nn == N - 1 ? RedType{minValue, idx[nn]} : update ? topK[nn + 1] : topK[nn]; + } + // get the next largest value + packedMax = topK[0].reduce(warp); + RedType::unpack(out[kk], outIdx[kk], packedMax); + } +}; + +#undef TOPK_SWAP +} // namespace topk +} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu index 613e6e1dff2..16b0ca98254 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu @@ -68,7 +68,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 { TLLM_CHECK_WITH_INFO(topK <= 8, "For DeepSeek routing method, must have topK <= 8"); TLLM_CHECK_WITH_INFO(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); - moe::dev::routing::Data routingData; + moe::dev::routing::routingDeepSeek::Data routingData; routingData.mDtypeExpW = btg::Dtype::Bfloat16; routingData.mUsePdl = true; @@ -104,7 +104,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumLocalExperts = localNumExperts; routingData.mRouteScale = routedScalingFactor; routingData.mUseRoutingSoftmax = false; - moe::dev::routing::run(routingData, stream); + moe::dev::routing::routingDeepSeek::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::Llama4) { @@ -113,7 +113,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 { TLLM_LOG_WARNING("For Llama routing method, nGroup/topkGroup is ignored, got %d/%d.", nGroup, topkGroup); } - moe::dev::routingLlama4::Data routingData; + moe::dev::routing::routingLlama4::Data routingData; routingData.mDtypeExpW = btg::Dtype::Bfloat16; routingData.mUsePdl = true; @@ -149,12 +149,12 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumLocalExperts = localNumExperts; // routingData.mRouteScale = routed_scaling_factor; // routingData.mUseRoutingSoftmax = false; - moe::dev::routingLlama4::run(routingData, stream); + moe::dev::routing::routingLlama4::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::Renormalize /* default */ || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */) { - moe::dev::routingRenormalize::Data routingData; + moe::dev::routing::routingRenormalize::Data routingData; // // Config @@ -196,7 +196,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; - moe::dev::routingRenormalize::run(routingData, stream); + moe::dev::routing::routingRenormalize::run(routingData, stream); } else { diff --git a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp index 8d4de68e9c3..e9ed1d0ce90 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp @@ -189,9 +189,9 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest void callTestedFunction( RoutingKernelTestParam const& param, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override { - moe::dev::routing::Data routingData; + moe::dev::routing::routingDeepSeek::Data routingData; setParams(param, routingData); - moe::dev::routing::run(routingData, mStream->get()); + moe::dev::routing::routingDeepSeek::run(routingData, mStream->get()); } void verifyExpertRoutingIndices(RoutingKernelTestParam const& param) diff --git a/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp b/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp index 1996c038449..5f936adf700 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp @@ -83,7 +83,7 @@ class RoutingLlama4KernelTest : public RoutingKernelTest [](PackedFloat const& a, PackedFloat const& b) { return ( - (a.score > b.score) || (a.score == b.score && a.idx > b.idx)); //@TODO: check if this is correct + (a.score > b.score) || (a.score == b.score && a.idx < b.idx)); //@TODO: check if this is correct }); // Apply sigmoid to the top-k scores @@ -127,9 +127,9 @@ class RoutingLlama4KernelTest : public RoutingKernelTest void callTestedFunction( RoutingKernelTestParam const& param, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override { - moe::dev::routingLlama4::Data routingData; + moe::dev::routing::routingLlama4::Data routingData; setParams(param, routingData); - moe::dev::routingLlama4::run(routingData, mStream->get()); + moe::dev::routing::routingLlama4::run(routingData, mStream->get()); } }; diff --git a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp index f4e2b7c7d4b..55f8e614b5f 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp @@ -171,9 +171,9 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest void callTestedFunction( RoutingKernelTestParam const& param, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override { - moe::dev::routingRenormalize::Data routingData; + moe::dev::routing::routingRenormalize::Data routingData; setParams(param, routingData); - moe::dev::routingRenormalize::run(routingData, mStream->get()); + moe::dev::routing::routingRenormalize::run(routingData, mStream->get()); } }; diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.h b/cpp/tests/unit_tests/kernels/routing/routingTest.h index 890bae74627..69f9bcec5c4 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.h +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.h @@ -42,7 +42,7 @@ namespace tc = tensorrt_llm::common; namespace tk = tensorrt_llm::kernels; namespace trk = tensorrt_llm::runtime::kernels; -using PackedFloat = moe::dev::routingRenormalize::PackedScoreIdx; +using PackedFloat = moe::dev::routing::PackedScoreIdx; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -197,7 +197,7 @@ inline bool isClose(T* hostData, T* hostTest, int n, std::string const& name, fl inline auto comp = [](PackedFloat const& a, PackedFloat const& b) { - return ((a.score > b.score) || (a.score == b.score && a.idx > b.idx)); //@TODO: check if this is correct + return ((a.score > b.score) || (a.score == b.score && a.idx < b.idx)); //@TODO: check if this is correct }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -326,7 +326,7 @@ class RoutingKernelTest : public testing::Test void runTest(RoutingKernelTestParam const& param); protected: - using PackedType = moe::dev::routingRenormalize::PackedScoreIdx; + using PackedType = moe::dev::routing::PackedScoreIdx; virtual size_t getDeviceWorkspaceSize(RoutingKernelTestParam const& param) {