diff --git a/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu b/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu index 19eb4be4c16..efa69c70988 100644 --- a/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu +++ b/cpp/tensorrt_llm/kernels/noAuxTcKernels.cu @@ -32,11 +32,14 @@ TRTLLM_NAMESPACE_BEGIN namespace kernels { static constexpr int WARP_SIZE = 32; +static constexpr int NumNemotronExperts = 512; static constexpr int NumKimiK2Experts = 384; static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); static constexpr int MaxNumExpertsUnit = 128; static constexpr int NumTopGroupScores = 2; -static constexpr int MaxNumTopExperts = 8; +static constexpr int DefaultMaxNumTopExperts = 8; +static constexpr int MaxSupportedTopExperts = 22; static constexpr int MaxNumTopGroups = 4; static __device__ inline float sigmoid_accurate(float x) @@ -44,7 +47,8 @@ static __device__ inline float sigmoid_accurate(float x) return 0.5f * tanhf(0.5f * x) + 0.5f; } -template +template __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, IdxT* topkIndices, BiasT* routingBias, int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, int64_t const topk, int64_t const numExperts, int64_t const numExpertsPerGroup, double const routedScalingFactor) @@ -132,7 +136,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx /* minValue */ invalidScoreFloat); // get the final group score and write it to shared - if (laneIdx == 0) + if (warp.thread_rank() == 0) { auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; smemGroupScores[warpIdx] = groupScore; @@ -151,9 +155,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, /* minValue */ invalidScoreFloat); - // final expert selection: get relevant indexes and scores from shared - #pragma unroll for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of numGroup @@ -161,12 +163,11 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx; expertScoreGroup[ii] - = groupIdx < numGroup && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat; + = (ii < topkGroup) && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat; } - tensorrt_llm::kernels::reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, - expertIdxGroup, - /* minValue */ invalidScoreFloat, topk); + tensorrt_llm::kernels::reduce_topk::reduceTopK( + warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, /* minValue */ invalidScoreFloat, topk); } } else if constexpr (MaxNumExperts > MaxNumExpertsUnit) @@ -197,11 +198,16 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; } + else if (laneIdx >= topk && laneIdx < MaxNumTopExperts) + { + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreFloat; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1; + } } __syncthreads(); if (warpIdx == 0) { - int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1; + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1; float intermidiateScore[NumInterTopKPerThread]; int32_t intermidiateExpert[NumInterTopKPerThread]; for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE) @@ -268,11 +274,11 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk { // Check if we can use the optimized deepseek_v3_topk_kernel - bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts); + bool const is_single_group = (n_group <= 1) && (num_experts <= MaxSupportedExpertCount); int64_t const experts_per_group = num_experts / n_group; - bool const is_multi_group = (n_group != 1) && (num_experts <= NumDeepseekExperts) - && (experts_per_group <= WARP_SIZE) && (experts_per_group * topk_group <= MaxNumExpertsUnit); + bool const is_multi_group = (n_group > 1) && (num_experts <= NumDeepseekExperts) && (experts_per_group <= WARP_SIZE) + && (experts_per_group * topk_group <= MaxNumExpertsUnit); if (is_single_group || is_multi_group) { @@ -281,7 +287,20 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk int num_threads = NumDeepseekExperts; if (is_single_group) { - if (num_experts > MaxNumExpertsUnit) + // Special case for Nemotron, which selects top 22 from 512 experts, and 1 group only. + if (num_experts == NumNemotronExperts && n_group == 1 && topk == MaxSupportedTopExperts) + { + kernel_instance = &deepseek_v3_topk_kernel; + num_threads = NumNemotronExperts; + } + else if (num_experts > NumKimiK2Experts && num_experts <= MaxSupportedExpertCount) + { + kernel_instance + = &deepseek_v3_topk_kernel; + num_threads = MaxSupportedExpertCount; + } + else if (num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts) { kernel_instance = &deepseek_v3_topk_kernel; num_threads = NumKimiK2Experts; diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h index 7e8c4fb7208..7edc3d1953c 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h @@ -182,37 +182,37 @@ namespace moe::dev TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag, forceFloatInput, numExperts) \ +#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, forceFloatInput, numExperts, numTopExperts) \ if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) \ { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, numThreads, \ - smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, numBlocks, \ + numThreads, smemSize, stream); \ } \ else if (data.mDtypeExpW == tg::Dtype::Fp32) \ { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, numThreads, \ - smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, numBlocks, \ + numThreads, smemSize, stream); \ } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \ { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) \ { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) \ { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } \ else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } \ else \ { \ diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu index 462fd5a091e..6937a34ccd9 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu @@ -23,11 +23,13 @@ namespace routingDeepSeek { //////////////////////////////////////////////////////////////////////////////////////////////////// - +static constexpr int NumNemotronExperts = 512; static constexpr int NumKimiK2Experts = 384; static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); static constexpr int NumTopGroupScores = 2; -static constexpr int MaxNumTopExperts = 8; +static constexpr int DefaultMaxNumTopExperts = 8; +static constexpr int MaxSupportedTopExperts = 22; static constexpr int MaxNumTopGroups = 4; static constexpr int MaxNumGroups = 8; @@ -125,8 +127,8 @@ __global__ void routingMainKernel(KernelParams params) int32_t topGroupIdx[MaxNumTopGroups]; float expertScoreGroup[MaxNumTopGroups]; int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[MaxNumTopExperts]; // bound of params.mTopK - int32_t topExperts[MaxNumTopExperts]; + float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK + int32_t topExperts[KernelParams::MaxNumTopExperts]; if constexpr (KernelParams::UseGroups) { @@ -152,7 +154,6 @@ __global__ void routingMainKernel(KernelParams params) topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, /* minValue */ invalidScoreFloat); // final expert selection: get relevant indexes and scores from shared - #pragma unroll for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of params.mNumLimitedGroups @@ -164,7 +165,8 @@ __global__ void routingMainKernel(KernelParams params) // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, // so the access is safe here - expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected + expertScoreGroup[ii] + = (ii < params.mNumLimitedGroups) && (groupIdx < params.mNumExpertGroups) && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat; } @@ -177,7 +179,7 @@ __global__ void routingMainKernel(KernelParams params) { // without groups, each thread just takes `MaxNumTopGroups` experts int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; + int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; if (warpIdx < NumExpertWarps) @@ -196,14 +198,20 @@ __global__ void routingMainKernel(KernelParams params) if (laneIdx < params.mTopK) { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } + else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) + { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = invalidScoreFloat; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] + = MaxSupportedExpertCount - 1; } } __syncthreads(); if (warpIdx == 0) { - int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1; + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; float intermidiateScore[NumInterTopKPerThread]; int32_t intermidiateExpert[NumInterTopKPerThread]; for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) @@ -295,7 +303,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Ke cudaGridDependencySynchronize(); } routingPermutation(params, nullptr, warpIdx, clusterBlockRank); + KernelParams::MaxNumTopExperts, /*LoadExpertIdxFromGlobal=*/true>(params, nullptr, warpIdx, clusterBlockRank); } #else __global__ void routingIndicesClusterKernel(KernelParams params) @@ -558,6 +566,10 @@ int constexpr getMaxNumExperts(int32_t numExperts) { return NumKimiK2Experts; } + else if (numExperts <= NumNemotronExperts) + { + return NumNemotronExperts; + } else { TLLM_LOG_ERROR("Unsupported numExperts"); @@ -571,17 +583,30 @@ int constexpr getMaxNumExperts(int32_t numExperts) if (data.mNumExperts <= topk::MaxNumExpertsUnit) \ { \ LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, topk::MaxNumExpertsUnit); \ + stream, extraFlag1, forceFloatInput, topk::MaxNumExpertsUnit, DefaultMaxNumTopExperts); \ } \ else if (data.mNumExperts <= NumDeepseekExperts) \ { \ LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, NumDeepseekExperts); \ + stream, extraFlag1, forceFloatInput, NumDeepseekExperts, DefaultMaxNumTopExperts); \ } \ else if (data.mNumExperts <= NumKimiK2Experts) \ { \ LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, NumKimiK2Experts); \ + stream, extraFlag1, forceFloatInput, NumKimiK2Experts, DefaultMaxNumTopExperts); \ + } \ + else if (data.mNumExperts <= NumNemotronExperts) \ + { \ + if (data.mTopK <= DefaultMaxNumTopExperts) \ + { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, extraFlag1, forceFloatInput, NumNemotronExperts, DefaultMaxNumTopExperts); \ + } \ + else if (data.mTopK <= MaxSupportedTopExperts) \ + { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, extraFlag1, forceFloatInput, NumNemotronExperts, MaxSupportedTopExperts); \ + } \ } \ else \ { \ @@ -603,25 +628,6 @@ void run(Data& data, void* stream) (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize, "If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required"); TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); - TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d", - MaxNumTopGroups, data.mNumLimitedGroups); - TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d", - MaxNumTopExperts, data.mTopK); - TLLM_CHECK_WITH_INFO(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK); - TLLM_CHECK_WITH_INFO(data.mTopK * data.mNumLimitedGroups <= WarpSize, - "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK, - data.mNumLimitedGroups); - TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxNumTopExperts, "Routing kernel expects %d to be at most #experts %d", - MaxNumTopExperts, data.mNumExperts); - TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumKimiK2Experts, "Routing kernel expects #experts %d <= #threads %d", - data.mNumExperts, NumKimiK2Experts); - TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups, - "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, - data.mNumExpertGroups); - // Note: Routing-specific constraints (experts per group, topK limits) are checked later - // only when routing is actually needed (data.mPtrTopKIds == nullptr) - TLLM_CHECK_WITH_INFO( - data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); int const numBlocks = data.mNumTokens; int const numThreadsHist = getMaxNumExperts(data.mNumExperts); @@ -655,9 +661,18 @@ void run(Data& data, void* stream) int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; if (data.mPtrTopKIds == nullptr) { + TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxSupportedTopExperts, + "Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, data.mNumExperts); + TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExpertCount, + "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount); + TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", + MaxSupportedTopExperts, data.mTopK); + // Routing needs to be executed - validate routing kernel constraints if (data.mNumExpertGroups > 1) { + // Note: Routing-specific constraints (experts per group, topK limits) are checked when routing is actually + // needed (data.mPtrTopKIds == nullptr) TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups, "Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups); TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0, @@ -667,14 +682,17 @@ void run(Data& data, void* stream) "Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts " "per group", WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups); - } - else - { - TLLM_CHECK_WITH_INFO(data.mTopK <= topk::MaxNumTopK, "Routing kernel expects top K %d to be <= max topk %d", - data.mTopK, topk::MaxNumTopK); + TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, + "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); + + TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups, + "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, + data.mNumExpertGroups); + TLLM_CHECK_WITH_INFO(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", + data.mNumExperts); } - int const numThreadsMain = data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts; + int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); LAUNCH_ROUTING_DEEPSEEK(data, /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, /*smemSize=*/0, // No dynamic smem diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h index d5aed6dbc9f..888e04f2541 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h @@ -189,13 +189,15 @@ struct Data : public DataBase bool mUseRoutingSoftmax; }; -template +template struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr bool UseGroups = UseGroups_; + static constexpr int MaxNumTopExperts = MaxNumTopExperts_; PackedScoreIdx* mPtrTopKPacked = nullptr; diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh index 2797baa6a9e..7eab1c82a11 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh @@ -35,7 +35,7 @@ namespace cg = cooperative_groups; static constexpr int WarpSize = 32; static constexpr int MaxNumExpertsUnit = 128; -static constexpr int MaxNumTopK = 10; +static constexpr int MaxSupportedTopExperts = 22; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu index 7a9cc1f7323..67b6913aaf7 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu @@ -25,7 +25,7 @@ static constexpr int NumExpertsLimit = 512; static constexpr int NumThreads = 1024; static constexpr int NumWarps = NumThreads / WarpSize; -static constexpr int MaxNumTopExperts = 10; +static constexpr int MaxSupportedTopExperts = 10; static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; @@ -34,8 +34,8 @@ static constexpr int BlockKernelMaxNumTokens = 4; template __forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile const& warp, - DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts], - int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts, int32_t topK, + DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxSupportedTopExperts], + int32_t (&warpTopKExpertIdx)[MaxSupportedTopExperts], int32_t const laneIdx, int32_t const numExperts, int32_t topK, InputType const* ptrScores, bool const normTopkProb, bool const applySoftmaxAfterTopK = true) { DataType minScore = DataType{-INFINITY}; @@ -149,8 +149,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesBlo BaseType score[VecSize]; int32_t idx[VecSize]; - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t warpTopKExpertIdx[MaxNumTopExperts]; + BaseType warpTopKScore[MaxSupportedTopExperts]; + int32_t warpTopKExpertIdx[MaxSupportedTopExperts]; BaseType minScore = BaseType{-INFINITY}; if (validToken) @@ -306,7 +306,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts]; + __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxSupportedTopExperts]; uint32_t const clusterBlockRank = blockIdx.x; @@ -332,8 +332,8 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu BaseType score[VecSize]; int32_t idx[VecSize]; - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t warpTopKExpertIdx[MaxNumTopExperts]; + BaseType warpTopKScore[MaxSupportedTopExperts]; + int32_t warpTopKExpertIdx[MaxSupportedTopExperts]; BaseType minScore = BaseType{-INFINITY}; if (validToken) @@ -356,12 +356,12 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu if (params.mPtrScores != nullptr) { - routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); } else { - routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); } } @@ -417,8 +417,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHis // over all warps/tokens BaseType allScores[VecSize]; int32_t allExpertIdx[VecSize]; - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t warpTopKExpertIdx[MaxNumTopExperts]; + BaseType warpTopKScore[MaxSupportedTopExperts]; + int32_t warpTopKExpertIdx[MaxSupportedTopExperts]; for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) { auto scoreOffset = tokenIdx * params.mNumExperts; @@ -486,8 +486,8 @@ void run(Data const& data, void* stream) TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); - TLLM_CHECK_WITH_INFO(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK experts <= %d, got %d", - MaxNumTopExperts, data.mTopK); + TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", + MaxSupportedTopExperts, data.mTopK); TLLM_CHECK_WITH_INFO(data.mNumExperts <= NumExpertsLimit, "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, NumExpertsLimit); // static_assert(MaxNumExperts <= NumThreads, "#experts must be bounded by #threads"); diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu index d348d95cb62..81e420ec57a 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu @@ -70,7 +70,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 { if (routingMethodType == RoutingMethodType::DeepSeekV3) { - TLLM_CHECK_WITH_INFO(topK <= 8, "For DeepSeek routing method, must have topK <= 8"); + TLLM_CHECK_WITH_INFO(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); TLLM_CHECK_WITH_INFO(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); moe::dev::routing::routingDeepSeek::Data routingData; routingData.mDtypeExpW = btg::Dtype::Bfloat16; diff --git a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp index 81746654a4f..c9d9085614a 100644 --- a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp @@ -106,7 +106,7 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional 1) { TORCH_CHECK(static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3, "Routing kernel with groups implies DeepSeekV3 routing method."); diff --git a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp index b8e688d1d3d..2db4e2bf6c5 100644 --- a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp @@ -104,7 +104,7 @@ at::Tensor run_fp8_block_scale_moe(at::optional const& routing_logit TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape."); } - if (n_group.has_value() && n_group.value() != 0) + if (n_group.has_value() && n_group.value() > 1) { TORCH_CHECK(static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3, "Routing kernel with groups implies DeepSeekV3 routing method."); diff --git a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp index 9681be6e7af..efefc066321 100644 --- a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp @@ -107,7 +107,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional con TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape."); } - if (n_group.has_value() && n_group.value() != 0) + if (n_group.has_value() && n_group.value() > 1) { TORCH_CHECK(static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3, "Routing kernel with groups implies DeepSeekV3 routing method."); diff --git a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp index 087871593e8..08bce0611b0 100644 --- a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp @@ -114,7 +114,7 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional TORCH_CHECK(routing_bias.value().sizes()[0] == num_experts, "routing_bias has incorrect shape."); } - if (n_group.has_value() && n_group.value() != 0) + if (n_group.has_value() && n_group.value() > 1) { TORCH_CHECK(static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3, "Routing kernel with groups implies DeepSeekV3 routing method."); diff --git a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp index 3d82670472d..0467c174965 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp @@ -244,6 +244,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384) this->runTest(param); }; +TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization512) +{ + RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 + /*numExperts=*/512, /*topK=*/22, + /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, + /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + this->runTest(param); +}; + TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization) { RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10 @@ -310,6 +321,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384) this->runTest(param); }; +TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization512) +{ + RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, + /*numExperts=*/512, /*topK=*/22, + /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, + /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + this->runTest(param); +}; + TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization) { RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300, @@ -332,6 +354,17 @@ TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384) this->runTest(param); }; +TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization512) +{ + RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300, + /*numExperts=*/512, /*topK=*/22, + /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, + /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, + /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, + /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + this->runTest(param); +}; + TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2) { RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10, diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index d879c6b0031..85e2b2c98d6 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -263,7 +263,8 @@ def noaux_tc(self, logits, e_score_correction_bias): ) self.is_fused = False else: - if num_experts > 384 or self.top_k > 8: + # We have special implementation for n_group == 1, top_k == 22 and num_experts == 512 for Nemotron Super v3. + if num_experts > 512 or (self.top_k > 8 and self.top_k != 22): if (self.is_fused): warnings.warn( "The configuration is not supported by the fused routing kernel. We have to use the original pytorch implementation." @@ -292,7 +293,11 @@ def noaux_tc(self, logits, e_score_correction_bias): score_mask = group_mask.unsqueeze(-1).expand( scores_shape[:-1] + [n_group, scores_shape[-1] // n_group]).reshape(scores_shape) - scores_with_bias = scores_with_bias * score_mask + scores_with_bias = torch.where( + score_mask.bool(), scores_with_bias, + torch.tensor(float('-inf'), + dtype=scores_with_bias.dtype, + device=scores_with_bias.device)) _, topk_idx = torch.topk(scores_with_bias, k=self.top_k, dim=-1, diff --git a/tests/unittest/_torch/thop/parallel/test_noaux_tc.py b/tests/unittest/_torch/thop/parallel/test_noaux_tc.py index d1c44c0ac8c..0e1437034fa 100644 --- a/tests/unittest/_torch/thop/parallel/test_noaux_tc.py +++ b/tests/unittest/_torch/thop/parallel/test_noaux_tc.py @@ -9,6 +9,7 @@ (256, 8, 4, 8), (72, 1, 1, 6), (384, 1, 1, 8), + (512, 1, 1, 22), ]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) diff --git a/tests/unittest/_torch/thop/serial/test_moe.py b/tests/unittest/_torch/thop/serial/test_moe.py index 99ae844fc61..e252fc6047a 100644 --- a/tests/unittest/_torch/thop/serial/test_moe.py +++ b/tests/unittest/_torch/thop/serial/test_moe.py @@ -1008,6 +1008,17 @@ class TestMoeFp4: "routing_method_type": RoutingMethodType.DeepSeekV3 }, id="RoutingDSv3"), + pytest.param( + { + "num_experts": 512, + "top_k": 22, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3 + }, + id="RoutingDS_SuperV3"), pytest.param( { "num_experts": 72, @@ -1238,7 +1249,7 @@ def run_moe_fp4_test(self, num_tokens: int, hidden_size: int, pytest.skip("https://nvbugs/5434352") assert top_k <= num_experts - assert top_k <= 10 + assert top_k <= 22 assert num_experts % 4 == 0 if use_topk_as_input: