diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h index 7f8e2b06b00..72996839037 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h @@ -233,100 +233,6 @@ namespace moe::dev TLLM_LOG_ERROR("Unsupported pair"); \ } -#define LAUNCH_TILEN(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mPaddingLog2 > 0) \ - { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, true), kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else \ - { \ - LAUNCH_PDL(data, coopLaunch, LAUNCH_ESC(types, false), kernel, numBlocks, numThreads, smemSize, stream); \ - } - -#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mDtypeExpW == tg::Dtype::Fp32) \ - { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ - { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ - } - -#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, forceFloatInput, numExperts, numTopExperts) \ - if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Fp32) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ - } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define LAUNCH_ROUTING_WITH_NUM_EXPERTS( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, numExperts, numTopExperts) \ - if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Fp32) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \ - { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ - } - //////////////////////////////////////////////////////////////////////////////////////////////////// namespace activation { diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu deleted file mode 100644 index b59580f9f15..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingDeepSeek.cu +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "routingDeepSeek/RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Forward declarations for split-compiled launch wrappers. -void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream); -void launchInitExpertCounts(Data& data, int numThreadsHist, void* stream); -void launchClusterKernel(Data& data, int numThreadsHist, void* stream); -void launchCoopKernel(Data& data, int numBlocksCoop, int numThreadsHist, void* stream); -void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, void* stream); -void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, void* stream); - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run(Data& data, void* stream) -{ - TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, - "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) - { - TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for DeepSeek routing."); - } - if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToExpandedIdx != nullptr - || data.mPtrPermutedIdxToTokenIdx != nullptr) - TLLM_CHECK_WITH_INFO( - (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize, - "If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required"); - TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); - int const numBlocks = data.mNumTokens; - int const numThreadsHist = getMaxNumExperts(data.mNumExperts); - - bool const useSingleCluster = data.mNumTokens <= 1024; - if (!useSingleCluster) - { - // Reset the global histograms (not used in single-cluster code path). - // Cover both for the cooperative and two-kernel code paths. - TLLM_CHECK_WITH_INFO( - data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); - } - else - { - data.mPtrExpertCounts = nullptr; // Set it to nullptr for single-cluster code path, as it won't be used - } - - // Number of blocks we can use in the cooperative kernel - // The number of blocks must be: - // >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉ - // <= numSms, assuming an occupancy of 1 block/SM - // - // If too small for the given numTokens, fall back to the less performant two-step method. - // - // The upper bound is a strict requirement. The number of blocks should be determined by querying - // the device properties, or conservatively low. - static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); - // WAR: Reserve 8 SMs for overlapping kernels. - int const numBlocksCoop = smCount - 8; - - // Maximum number of tokens supported by the kernel using a cooperative launch. - int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; - if (data.mPtrTopKIds == nullptr) - { - TLLM_CHECK_WITH_INFO(data.mNumExperts >= MaxSupportedTopExperts, - "Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, data.mNumExperts); - TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExpertCount, - "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount); - TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", - MaxSupportedTopExperts, data.mTopK); - - // Routing needs to be executed - validate routing kernel constraints - if (data.mNumExpertGroups > 1) - { - // Note: Routing-specific constraints (experts per group, topK limits) are checked when routing is actually - // needed (data.mPtrTopKIds == nullptr) - TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups, - "Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups); - TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0, - "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts, - data.mNumExpertGroups); - TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize, - "Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts " - "per group", - WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups); - TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, - "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); - - TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups, - "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, - data.mNumExpertGroups); - TLLM_CHECK_WITH_INFO(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", - data.mNumExperts); - } - - int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); - launchMainKernel(data, numBlocks, numThreadsMain, stream); - } - else - { - // Reset the global histograms. - launchInitExpertCounts(data, numThreadsHist, stream); - } - - if (data.mPtrPermutedIdxSize != nullptr) - { - if (useSingleCluster) - { - launchClusterKernel(data, numThreadsHist, stream); - } - else if (data.mNumTokens <= maxTokensCoop) - { - launchCoopKernel(data, numBlocksCoop, numThreadsHist, stream); - } - else - { - const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; - const int32_t histogramEltsPerBlock = 8 * numThreadsHist; - const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; - - // Limit grid size (both kernels use a grid-stride loop). - const int32_t maxNumBlocks = 1024; - - int const numBlocksHistogram - = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); - int const numBlocksOffsets - = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - - launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); - launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu deleted file mode 100644 index ff4bb808d92..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingRenormalize.cu +++ /dev/null @@ -1,109 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "routingRenormalize/RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Forward declarations of per-kernel launch wrappers (defined in routingRenormalize/*.cu). -void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream); -void launchClusterKernel(Data const& data, void* stream); -void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, uint32_t numThreadsHist, void* stream); -void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream); -void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, void* stream); -void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, void* stream); - -//////////////////////////////////////////////////////////////////////////////////////////////////// -void run(Data const& data, void* stream) -{ - TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, - "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) - { - TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Renormalize routing."); - } - TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr - && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, - "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); - TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", - MaxSupportedTopExperts, data.mTopK); - TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExperts, - "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, MaxSupportedExperts); - // similar check - TLLM_CHECK_WITH_INFO( - data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); - - bool const useSingleBlock = data.mNumTokens <= BlockKernelMaxNumTokens - || (data.mNumTokens <= DynBlockKernelMaxNumTokens && data.mNumExperts <= DynBlockKernelMaxNumExperts); - - bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) - ? MaxNumTokensSingleClusterScores - : MaxNumTokensSingleCluster); - - if (!useSingleCluster && !useSingleBlock) - { - TLLM_CHECK_WITH_INFO((data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), - "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); - TLLM_CHECK_WITH_INFO( - data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); - } - uint32_t const numThreadsHist = min(1024, getMaxNumExperts(data.mNumExperts)); - if (useSingleBlock) - { - //@TODO: For now we use the single block kernel for cases with token number no larger than 4. - // We will future tune this threshold based on the performance. - launchBlockKernel(data, numThreadsHist, stream); - } - else if (useSingleCluster) - { - launchClusterKernel(data, stream); - } - else - { - uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; - uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; - uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; - - // Limit grid size (all kernels use a grid-stride loop). - uint32_t const maxNumBlocks = 1024; - - int const numBlocksHistogram - = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); - int const numBlocksOffsets - = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - - if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) - { - launchHistogramScoresKernel(data, maxNumBlocks, numThreadsHist, stream); - } - else - { - // Reset the global histograms. - launchInitExpertCounts(data, numThreadsHist, stream); - } - launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); - launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/IntFastDiv.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/IntFastDiv.h similarity index 100% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/IntFastDiv.h rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/IntFastDiv.h diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchBlockKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustom.cu similarity index 51% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchBlockKernel.cu rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustom.cu index 2a4f9257aa9..93af3fc0ede 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchBlockKernel.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustom.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,13 +13,92 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "RoutingRenormalizeCommon.cuh" + +// Custom routing: entry point, kernel definitions, and launch wrappers. +// +// Kernel inventory: +// 1. routingIndicesBlockKernel — single-block fused kernel (≤4 tokens) +// 1b. routingIndicesDynBlockKernel — dynamic-block fused kernel (≤16 tokens, ≤512 experts) +// 2. routingIndicesClusterKernel — single-cluster fused kernel (≤256 tokens, SM90+) +// 3. routingIndicesHistogramScoresKernel — TopK + histogram from raw scores +// 4. routingIndicesCoopKernel — cooperative histogram + offsets (defined in RoutingKernel.cuh) +// 5. routingInitExpertCounts — zero expert counts (defined in RoutingKernel.cuh) +// 6. routingIndicesHistogramKernel — histogram from packed TopK (defined in RoutingKernel.cuh) +// 7. routingIndicesOffsetsKernel — prefix-scan + permutation (defined in RoutingKernel.cuh) + +#include "RoutingCustomPolicy.cuh" namespace moe::dev::routing { -namespace routingRenormalize +namespace routingCustom { +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Dual warp-level exclusive prefix scan over NumExpertWarps * 32 values. +// Scans val1 and val2 simultaneously while sharing the same two __syncthreads() barriers, +// reducing 4 barriers (two separate scans) to 2. +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +__device__ __forceinline__ void warpExclusiveScan(int32_t val1, int32_t val2, int32_t laneIdx, int32_t warpIdx, + int32_t* warpTotals1, int32_t* warpTotals2, int32_t& prefix1, int32_t& prefix2, int32_t& totalSum1) +{ + static_assert(NumExpertWarps <= WarpSize, "NumExpertWarps must fit in one warp for the cross-warp scan"); + + int32_t inc1 = val1, inc2 = val2; +#pragma unroll + for (int j = 1; j < WarpSize; j *= 2) + { + int32_t n1 = __shfl_up_sync(0xffffffff, inc1, j); + int32_t n2 = __shfl_up_sync(0xffffffff, inc2, j); + if (laneIdx >= j) + { + inc1 += n1; + inc2 += n2; + } + } + + if (warpIdx < NumExpertWarps && laneIdx == WarpSize - 1) + { + warpTotals1[warpIdx] = inc1; + warpTotals2[warpIdx] = inc2; + } + __syncthreads(); + + if (warpIdx == 0) + { + int32_t wt1 = (laneIdx < NumExpertWarps) ? warpTotals1[laneIdx] : 0; + int32_t wt2 = (laneIdx < NumExpertWarps) ? warpTotals2[laneIdx] : 0; +#pragma unroll + for (int j = 1; j < NumExpertWarps; j *= 2) + { + int32_t n1 = __shfl_up_sync(0xffffffff, wt1, j); + int32_t n2 = __shfl_up_sync(0xffffffff, wt2, j); + if (laneIdx >= j) + { + wt1 += n1; + wt2 += n2; + } + } + if (laneIdx < NumExpertWarps) + { + warpTotals1[laneIdx] = wt1; + warpTotals2[laneIdx] = wt2; + } + } + __syncthreads(); + + totalSum1 = warpTotals1[NumExpertWarps - 1]; + int32_t wp1 = (warpIdx > 0 && warpIdx < NumExpertWarps) ? warpTotals1[warpIdx - 1] : 0; + int32_t wp2 = (warpIdx > 0 && warpIdx < NumExpertWarps) ? warpTotals2[warpIdx - 1] : 0; + prefix1 = inc1 - val1 + wp1; + prefix2 = inc2 - val2 + wp2; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 1. Block kernel — single-block fused kernel for ≤4 tokens. +// Fuses TopK, histogram, prefix-scan, and permutation in one block. +// //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -29,7 +108,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa // types used in this kernel using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; using TypePacked = PackedScoreIdx; static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; // When MaxNumExperts > 1024, cap actual thread count at 1024 and let each thread handle @@ -63,7 +142,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // then wait on primary grid - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -76,7 +155,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa if (laneIdx < params.mTopK) { auto expertIdx = params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; - if (expertIdx != -1) + if (expertIdx > -1 && expertIdx < params.mNumExperts) { int offset = warpIdx * MaxNumExperts + expertIdx; smemKIdx[offset] = static_cast(laneIdx); @@ -91,18 +170,14 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa else if (params.mPtrScores != nullptr) { // in this case, each warp represents a token - BaseType score[VecSize]; - int32_t idx[VecSize]; - BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; - BaseType minScore = BaseType{-INFINITY}; if (validToken) { - routingTopKExperts(warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb); + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); if (laneIdx < params.mTopK) { @@ -115,6 +190,31 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa } } // end if (validToken) } + else if (params.mPtrTopKPacked != nullptr) + { + if (validToken) + { + if (laneIdx < params.mTopK) + { + auto const expandedIdx = warpIdx * params.mTopK + laneIdx; + auto const scoreIdx = params.mPtrTopKPacked[expandedIdx]; + int const expertIdx = static_cast(scoreIdx.idx); + if (expertIdx >= 0 && expertIdx < params.mNumExperts) + { + int const offset = warpIdx * MaxNumExperts + expertIdx; + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) + { + params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); + } + } + else if (params.mPtrExpandedIdxToPermutedIdx != nullptr) + { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = int32_t{-1}; + } + } + } + } __syncthreads(); // Each thread handles ExpertsPerThread contiguous experts. @@ -155,7 +255,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #pragma unroll for (int e = 0; e < ExpertsPerThread; e++) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { numCtaPerExpert[e] = divUpLog2(accExpertCount[e], params.mPaddingLog2); } @@ -174,7 +274,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #pragma unroll for (int e = 0; e < ExpertsPerThread; e++) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { tmpCountPerExpert[e] = divUpMulLog2(accExpertCount[e], params.mPaddingLog2); } @@ -205,7 +305,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa params.mPtrCtaIdxXyToBatchIdx[ctaOffsetPerExpert[e] + cta] = mappedLocalIdx; int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffsetPerExpert[e] + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffsetPerExpert[e], params.mPaddingLog2) + accExpertCount[e]; @@ -220,11 +320,10 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa } } - // at this point, we can write out padded count if (threadIdx.x == 0) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -236,14 +335,6 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // we can trigger the next kernel at this point - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) { #pragma unroll @@ -277,85 +368,43 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa } } } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Dual warp-level exclusive prefix scan over NumExpertWarps * 32 values. -// Scans val1 and val2 simultaneously while sharing the same two __syncthreads() barriers, -// reducing 4 barriers (two separate scans) to 2. -//////////////////////////////////////////////////////////////////////////////////////////////////// -template -__device__ __forceinline__ void warpExclusiveScan(int32_t val1, int32_t val2, int32_t laneIdx, int32_t warpIdx, - int32_t* warpTotals1, int32_t* warpTotals2, int32_t& prefix1, int32_t& prefix2, int32_t& totalSum1) -{ - static_assert(NumExpertWarps <= WarpSize, "NumExpertWarps must fit in one warp for the cross-warp scan"); - - int32_t inc1 = val1, inc2 = val2; -#pragma unroll - for (int j = 1; j < WarpSize; j *= 2) - { - int32_t n1 = __shfl_up_sync(0xffffffff, inc1, j); - int32_t n2 = __shfl_up_sync(0xffffffff, inc2, j); - if (laneIdx >= j) - { - inc1 += n1; - inc2 += n2; - } - } - - if (warpIdx < NumExpertWarps && laneIdx == WarpSize - 1) - { - warpTotals1[warpIdx] = inc1; - warpTotals2[warpIdx] = inc2; - } - __syncthreads(); - if (warpIdx == 0) +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger the secondary kernel AFTER all global memory writes (including permutation indices). + // The downstream kernels depend on all routing outputs being visible. + if (params.mUsePdl) { - int32_t wt1 = (laneIdx < NumExpertWarps) ? warpTotals1[laneIdx] : 0; - int32_t wt2 = (laneIdx < NumExpertWarps) ? warpTotals2[laneIdx] : 0; -#pragma unroll - for (int j = 1; j < NumExpertWarps; j *= 2) - { - int32_t n1 = __shfl_up_sync(0xffffffff, wt1, j); - int32_t n2 = __shfl_up_sync(0xffffffff, wt2, j); - if (laneIdx >= j) - { - wt1 += n1; - wt2 += n2; - } - } - if (laneIdx < NumExpertWarps) - { - warpTotals1[laneIdx] = wt1; - warpTotals2[laneIdx] = wt2; - } + cudaTriggerProgrammaticLaunchCompletion(); } - __syncthreads(); +#endif +} - totalSum1 = warpTotals1[NumExpertWarps - 1]; - int32_t wp1 = (warpIdx > 0 && warpIdx < NumExpertWarps) ? warpTotals1[warpIdx - 1] : 0; - int32_t wp2 = (warpIdx > 0 && warpIdx < NumExpertWarps) ? warpTotals2[warpIdx - 1] : 0; - prefix1 = inc1 - val1 + wp1; - prefix2 = inc2 - val2 + wp2; +void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesBlockKernel, 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// -// Dynamic-block routing kernel: uses a dynamic thread count and dynamic shared memory. // -// Compared to routingIndicesBlockKernel (which fixes blockDim = MaxExperts): -// 1. Thread count = min(max(numTokens*32, MaxExperts), 1024) so each token -// gets its own warp — eliminates the Phase-1 TopK batch loop for small batches. -// 2. Warp-level Hillis-Steele scan replaces CUB BlockScan, fusing two scans -// into one (2 barriers instead of 4) with no compile-time thread count dependency. -// 3. Dynamic shared memory enables flexible token counts (up to 16). +// 1b. Dynamic-block kernel — single-block with dynamic thread count and dynamic shared memory. +// +// Compared to routingIndicesBlockKernel (which fixes blockDim = MaxExperts): +// 1. Thread count = min(max(numTokens*32, MaxExperts), 1024) so each token +// gets its own warp — eliminates the Phase-1 TopK batch loop for small batches. +// 2. Warp-level Hillis-Steele scan replaces CUB BlockScan, fusing two scans +// into one (2 barriers instead of 4) with no compile-time thread count dependency. +// 3. Dynamic shared memory enables flexible token counts (up to 16). +// //////////////////////////////////////////////////////////////////////////////////////////////////// + template __global__ void routingIndicesDynBlockKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; using TypePacked = PackedScoreIdx; static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; static constexpr int NumThreadsExperts = MaxNumExperts <= 1024 ? MaxNumExperts : 1024; @@ -370,11 +419,6 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) int32_t const laneIdx = cutlass::arch::LaneId(); int32_t const numWarps = blockDim.x / WarpSize; - // Dynamic shared memory layout: - // [0 .. numSlots) : int8_t smemKIdx - // [numSlots .. numSlots*3) : int16_t smemOffset - // [aligned .. +NumExpertWarps] : int32_t warpTotals1 (scan: numCtaPerExpert) - // [+NumExpertWarps] : int32_t warpTotals2 (scan: tmpCountPerExpert) extern __shared__ char dynSmem[]; int const numSlots = params.mNumTokens * MaxNumExperts; int8_t* smemKIdx = reinterpret_cast(dynSmem); @@ -387,8 +431,6 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - // Initialize smemKIdx only — smemOffset is only read when kIdx >= 0, - // which implies Phase 2 has already written it (no init needed). for (int i = threadIdx.x; i < numSlots; i += blockDim.x) { smemKIdx[i] = int8_t{-1}; @@ -396,13 +438,13 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) __syncthreads(); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } #endif - // ── Phase 1: TopK — one warp per token (loop only when numTokens > numWarps) ── + // Phase 1: TopK — one warp per token (loop only when numTokens > numWarps) for (int tokenIdx = warpIdx; tokenIdx < params.mNumTokens; tokenIdx += numWarps) { if (params.mPtrTopKIds != nullptr) @@ -414,7 +456,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) { smemKIdx[tokenIdx * MaxNumExperts + expertIdx] = static_cast(laneIdx); } - else + else if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { params.mPtrExpandedIdxToPermutedIdx[tokenIdx * params.mTopK + laneIdx] = int32_t{-1}; } @@ -422,15 +464,13 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } else if (params.mPtrScores != nullptr) { - BaseType score[VecSize]; - int32_t idx[VecSize]; BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; auto scoreOff = tokenIdx * params.mNumExperts; - routingTopKExperts(warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOff, params.mNormTopkProb); + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOff, params); if (laneIdx < params.mTopK) { @@ -447,17 +487,25 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) { auto expandedIdx = tokenIdx * params.mTopK + laneIdx; auto scoreIdx = params.mPtrTopKPacked[expandedIdx]; - smemKIdx[tokenIdx * MaxNumExperts + static_cast(scoreIdx.idx)] = static_cast(laneIdx); - if (params.mPtrTopKWeights != nullptr) + int const expertIdx = static_cast(scoreIdx.idx); + if (expertIdx >= 0 && expertIdx < params.mNumExperts) + { + smemKIdx[tokenIdx * MaxNumExperts + expertIdx] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) + { + params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); + } + } + else if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { - params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = int32_t{-1}; } } } } __syncthreads(); - // ── Phase 2: Histogram — each expert-thread counts tokens assigned to its expert(s) ── + // Phase 2: Histogram int accExpertCount[ExpertsPerThread]; if (threadIdx.x < NumThreadsExperts) { @@ -493,7 +541,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } } - // ── Phase 3: Prefix-scan (merged dual warp-level scan, 2 barriers instead of 4) ── + // Phase 3: Prefix-scan (merged dual warp-level scan, 2 barriers instead of 4) int32_t numCtaPerExpert[ExpertsPerThread]; int32_t tmpCountPerExpert[ExpertsPerThread]; int32_t ctaOffsetPerExpert[ExpertsPerThread]; @@ -505,7 +553,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) { if (threadIdx.x < NumThreadsExperts) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { numCtaPerExpert[e] = divUpLog2(accExpertCount[e], params.mPaddingLog2); tmpCountPerExpert[e] = divUpMulLog2(accExpertCount[e], params.mPaddingLog2); @@ -546,7 +594,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } } - // ── Phase 4: CTA configs ── + // Phase 4: CTA configs if (threadIdx.x < NumThreadsExperts) { #pragma unroll @@ -564,7 +612,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) = (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; params.mPtrCtaIdxXyToBatchIdx[ctaOffsetPerExpert[e] + cta] = mappedLocalIdx; int32_t mnLimit1, mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffsetPerExpert[e] + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffsetPerExpert[e], params.mPaddingLog2) + accExpertCount[e]; @@ -583,7 +631,7 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) if (threadIdx.x == 0) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -596,13 +644,13 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif - // ── Phase 5: Permutation ── + // Phase 5: Permutation if (threadIdx.x < NumThreadsExperts) { for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) @@ -641,30 +689,364 @@ __global__ void routingIndicesDynBlockKernel(KernelParams params) } } +void launchDynBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream) +{ + int32_t const maxExperts = queryDispatchedMaxExperts(data); + int const numSlots = data.mNumTokens * maxExperts; + int const smemSize + = numSlots + numSlots * 2 + 128 + 2 * (maxExperts / WarpSize) * static_cast(sizeof(int32_t)); + int const threads = std::min(std::max(data.mNumTokens * static_cast(WarpSize), maxExperts), 1024); + + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesDynBlockKernel, 1, threads, smemSize, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 2. Cluster kernel — single-cluster fused kernel for ≤256 tokens (SM90+). +// Uses distributed shared memory across 8 blocks in a cluster. +// //////////////////////////////////////////////////////////////////////////////////////////////////// -void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream) +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams params) +{ + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; + using TypePacked = PackedScoreIdx; + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * KernelParams::MaxNumTopExperts]; + + uint32_t const clusterBlockRank = blockIdx.x; + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; + auto scoreOffset = warpTokenIdx * params.mNumExperts; + bool validToken = warpTokenIdx < params.mNumTokens; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + if (params.mUsePdl) + { + cudaGridDependencySynchronize(); + } + + if (params.mPtrScores != nullptr) + { + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + if (validToken) + { + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); + if (laneIdx < params.mTopK) + { + smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] + = TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; + } + } + } + + __cluster_barrier_arrive(); + __cluster_barrier_wait(); + + if (params.mPtrScores != nullptr) + { + routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); + } + else + { + routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); + } +} +#else +__global__ void __launch_bounds__(NumThreads) routingIndicesClusterKernel(KernelParams /* params */) +{ + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + +void launchClusterKernel(Data const& data, void* stream) +{ + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 3. HistogramScores kernel — computes TopK from raw scores and initializes expert counts. +// Used as step 1 of the multi-kernel pipeline when input is raw logits. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024) + routingIndicesHistogramScoresKernel(KernelParams params) +{ + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = typename KernelParams::ExpertSelectPolicy::template BaseType; + // Cap actual thread count at 1024 when MaxNumExperts > 1024. + static constexpr int NumThreadsBlock = KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024; + + // VecSize stays based on MaxNumExperts — each warp still processes all experts for one token. + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + int32_t const laneIdx = cutlass::arch::LaneId(); + int32_t const warpIdx = threadIdx.x / WarpSize; + // Use NumThreadsBlock (actual thread count) for grid-stride warp/thread addressing + int32_t const globalWarpIdx = blockIdx.x * NumThreadsBlock / WarpSize + warpIdx; + int32_t const globalWarpStride = gridDim.x * NumThreadsBlock / WarpSize; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Wait on primary grid. + if (params.mUsePdl) + { + cudaGridDependencySynchronize(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + // initialize the mPtrExpertCounts — use NumThreadsBlock for grid-stride + int32_t expertCountsNum = 2 * params.mNumExperts; + int32_t globalThreadIdx = blockIdx.x * NumThreadsBlock + threadIdx.x; + int32_t globalThreadStride = gridDim.x * NumThreadsBlock; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + + // in this case, each warp represents a token, and we use a grid-stride loop + // over all warps/tokens + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) + { + auto scoreOffset = tokenIdx * params.mNumExperts; + + KernelParams::ExpertSelectPolicy::template apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, params.mTopK, + params.mPtrScores + scoreOffset, params); + + if (laneIdx < params.mTopK) + { + PackedScoreIdx packedScore{ + static_cast(warpTopKScore[laneIdx]), static_cast(warpTopKExpertIdx[laneIdx])}; + params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger secondary kernel AFTER writing all packed scores, so the next kernel + // (routingIndicesHistogramKernel) sees the completed mPtrTopKPacked writes. + if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +} + +static void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, uint32_t numThreadsHist, void* stream) { - if (data.mNumTokens <= DynBlockKernelMaxNumTokens && data.mNumExperts <= DynBlockKernelMaxNumExperts) + LAUNCH_ROUTING_CUSTOM(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 4. Coop kernel — cooperative histogram + offsets via grid-sync. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream) +{ + if (data.mNumExperts <= NumExperts128Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts128Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts160Experts) { - int32_t const maxExperts = getMaxNumExperts(data.mNumExperts); - int const numSlots = data.mNumTokens * maxExperts; - int const smemSize - = numSlots + numSlots * 2 + 128 + 2 * (maxExperts / WarpSize) * static_cast(sizeof(int32_t)); - int const threads = std::min(std::max(data.mNumTokens * static_cast(WarpSize), maxExperts), 1024); + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts160Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts256Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts256Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts384Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts384Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts512Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts512Experts, NumTop8Experts); + } + else if (data.mNumExperts <= NumExperts576Experts) + { + LAUNCH_ROUTING_WITH_POLICIES(data, /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, + /*smemSize=*/0, stream, NoOpPreprocess, NoOpPostprocess, NumExperts576Experts, NumTop8Experts); + } + else + { + TLLM_LOG_ERROR("Coop kernel does not support numExperts > %d", NumExperts576Experts); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 5-7. Launch wrappers for shared kernels (defined in RoutingKernel.cuh): +// - InitExpertCounts (zero expert counts) +// - Histogram kernel (histogram from packed TopK) +// - Offsets kernel (prefix-scan + permutation) +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} - LAUNCH_ROUTING_RENORMALIZE( - data, false, routingIndicesDynBlockKernel, 1, threads, smemSize, stream, data.mDoSoftmaxBeforeTopK); +void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Entry point +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run(Data const& data, void* stream) +{ + TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to the shared post-topK pipeline which handles all path selection + // (single-block, single-cluster, coop, multi-kernel) automatically. + // No routing-method-specific logic needed. + if (data.mPtrTopKIds != nullptr || (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) + { + if (data.mPtrTopKIds != nullptr) + { + TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for custom routing."); + } + uint32_t const numThreadsHist = min(1024, getMaxNumExperts(data.mNumExperts)); + runPostTopKPipeline(data, numThreadsHist, stream); + return; + } + + // After this point, input is mPtrScores (raw logits that need topK computation). + TLLM_CHECK_WITH_INFO(data.mPtrScores != nullptr, "Expected mPtrScores to be non-null at this point."); + TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr + && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, + "Custom routing kernel expects permuted idx and grouped Gemm launch config buffers"); + TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", + MaxSupportedTopExperts, data.mTopK); + TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExperts, + "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, MaxSupportedExperts); + TLLM_CHECK_WITH_INFO( + data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + + static int const smMajor = tensorrt_llm::common::getSMVersion() / 10; + + bool const useStaticBlock = data.mNumTokens <= BlockKernelMaxNumTokens; + bool const useDynBlock = !useStaticBlock && data.mNumTokens <= DynBlockKernelMaxNumTokens + && data.mNumExperts <= DynBlockKernelMaxNumExperts; + bool const useSingleBlock = useStaticBlock || useDynBlock; + bool const useSingleCluster = (smMajor >= 9) && (data.mNumTokens <= MaxNumTokensSingleClusterScores); + + if (!useSingleCluster && !useSingleBlock) + { + TLLM_CHECK_WITH_INFO( + data.mPtrTopKPacked != nullptr, "When #tokens is large, `mPtrTopKPacked` is a required input."); + TLLM_CHECK_WITH_INFO( + data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); + } + + uint32_t const numThreadsHist = min(1024, getMaxNumExperts(data.mNumExperts)); + + // Last routing kernel: disable programmaticStreamSerializationAllowed so GEMM waits. + Data lastKernelData = data; + lastKernelData.mPdlAllowOverlap = false; + + if (useDynBlock) + { + launchDynBlockKernel(lastKernelData, numThreadsHist, stream); + } + else if (useStaticBlock) + { + launchBlockKernel(lastKernelData, numThreadsHist, stream); + } + else if (useSingleCluster) + { + launchClusterKernel(lastKernelData, stream); } else { - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesBlockKernel, 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); + uint32_t const maxNumBlocks = 1024; + + launchHistogramScoresKernel(data, maxNumBlocks, numThreadsHist, stream); + + bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) && (data.mPtrPermutedIdxSize != nullptr); + bool useCoop = false; + int numBlocksCoop = 0; + + if (canUseCoop) + { + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + numBlocksCoop = smCount - 8; + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + useCoop = (data.mNumTokens <= maxTokensCoop); + } + + if (useCoop) + { + launchInitExpertCounts(data, numThreadsHist, stream); + launchCoopKernel(lastKernelData, numBlocksCoop, numThreadsHist, stream); + } + else + { + uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; + uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; + uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + + int const numBlocksHistogram + = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets + = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); + launchOffsetsKernel(lastKernelData, numBlocksOffsets, numThreadsHist, stream); + } } } //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace routingRenormalize +} // namespace routingCustom } // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustomPolicy.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustomPolicy.cuh new file mode 100644 index 00000000000..4fe4c48b328 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingCustomPolicy.cuh @@ -0,0 +1,789 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "RoutingKernel.cuh" + +namespace moe::dev::routing +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Preprocess policies: applied to all expert scores BEFORE topK selection. +// +// Each policy must provide: +// - template using BaseType +// The data type used for intermediate score computation. +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data, populated from the host-side Data struct. +// Empty for policies that don't need extra data (zero register cost). +// - template +// static void apply(warp, score[VecSize], idx[VecSize], numExperts, params) +// Transforms scores in-place before topK selection. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// No-op: scores are passed through unchanged. +struct NoOpPreprocess +{ + /// BaseType: when no preprocess is applied, use the input type directly. + template + using BaseType = InputT; + + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& /*warp*/, + DataType (& /*score*/)[VecSize], int32_t const (& /*idx*/)[VecSize], int32_t /*numExperts*/, + ParamsT const& /*params*/) + { + } +}; + +/// Softmax: applies softmax over all expert scores before topK selection. +struct SoftmaxPreprocess +{ + /// BaseType: softmax is always computed in float for numerical stability. + template + using BaseType = float; + + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&score)[VecSize], int32_t const (& /*idx*/)[VecSize], int32_t /*numExperts*/, + ParamsT const& /*params*/) + { + calcSoftmax(warp, score); + } +}; + +/// Sigmoid: applies sigmoid(score) for topK selection (no bias). +/// Used by Cohere-style routing where expert selection is based on raw sigmoid scores. +struct SigmoidPreprocess +{ + /// BaseType: sigmoid is computed in float for numerical stability. + template + using BaseType = float; + + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& /*warp*/, + DataType (&score)[VecSize], int32_t const (&idx)[VecSize], int32_t numExperts, ParamsT const& /*params*/) + { +#pragma unroll + for (int i = 0; i < VecSize; i++) + { + float s = sigmoid_accurate(static_cast(score[i])); + score[i] = idx[i] < numExperts ? static_cast(s) : DataType{-INFINITY}; + } + } +}; + +/// SigmoidBias: applies sigmoid(score) + bias[expertIdx] for topK selection. +/// Used by DeepSeek-style routing where expert selection is based on biased sigmoid scores. +struct SigmoidBiasPreprocess +{ + /// BaseType: sigmoid is computed in float for numerical stability. + template + using BaseType = float; + + template + struct Params + { + // Store as void const* to support any bias dtype (float, bfloat16, etc.) without conversion. + void const* ptrRoutingBias = nullptr; + batchedGemm::trtllm::gen::Dtype dtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16; + + void set(routingCustom::Data const& data) + { + ptrRoutingBias = data.mPtrRoutingBias; + dtypeBias = data.mDtypeBias; + } + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& /*warp*/, + DataType (&score)[VecSize], int32_t const (&idx)[VecSize], int32_t numExperts, ParamsT const& params) + { +#pragma unroll + for (int i = 0; i < VecSize; i++) + { + float s = sigmoid_accurate(static_cast(score[i])); + float bias + = idx[i] < numExperts ? loadScalar(params.ptrRoutingBias, idx[i], params.dtypeBias) : float{-INFINITY}; + score[i] = static_cast(s + bias); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Postprocess policies: applied to the top-K scores AFTER topK selection. +// +// Each policy must provide: +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data. Empty when not needed. +// - template +// static void apply(warp, warpTopKScore[K], warpTopKExpertIdx[K], laneIdx, topK, params) +// Transforms top-K scores in-place after topK selection. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// No-op: top-K scores are left unchanged. +struct NoOpPostprocess +{ + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& /*warp*/, + DataType (& /*warpTopKScore*/)[K], int32_t const (& /*warpTopKExpertIdx*/)[K], int32_t /*laneIdx*/, + int32_t /*topK*/, ParamsT const& /*params*/) + { + } +}; + +/// Softmax: applies softmax over the top-K scores. +struct SoftmaxPostprocess +{ + template + struct Params + { + void set(routingCustom::Data const& /*data*/) {} + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], int32_t const (& /*warpTopKExpertIdx*/)[K], int32_t laneIdx, int32_t topK, + ParamsT const& /*params*/) + { + DataType minScore = DataType{-INFINITY}; + auto softmaxScore = calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK); + if (laneIdx < topK) + { + warpTopKScore[laneIdx] = softmaxScore; + } + } +}; + +/// SumNormalize: divides each top-K score by the sum of all top-K scores. +/// Used when softmax has already been applied before topK selection. +struct SumNormalizePostprocess +{ + template + struct Params + { + bool normTopkProb = true; + + void set(routingCustom::Data const& data) + { + normTopkProb = data.mNormTopkProb; + } + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], int32_t const (& /*warpTopKExpertIdx*/)[K], int32_t laneIdx, int32_t topK, + ParamsT const& params) + { + float sum = float{1.f}; + if (params.normTopkProb) + { + sum = static_cast(laneIdx < topK ? warpTopKScore[laneIdx] : 0); + sum = cg::reduce(warp, sum, cg::plus()); + } + if (laneIdx < topK) + { + warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; + } + } +}; + +/// ScaledSumNormalize: recovers un-biased sigmoid scores by subtracting per-expert bias from the +/// selection scores (sigmoid + bias), then normalizes by sum and applies routeScale. +/// Used by DeepSeek-style routing: final_weight = sigmoid(raw) * routeScale / (sum + epsilon). +/// DeepSeek uses epsilon=0 (no guard); MiniMax2 uses epsilon=1e-20 to prevent division by zero. +struct ScaledSumNormalizePostprocess +{ + template + struct Params + { + // Store as void const* to support any bias dtype (float, bfloat16, etc.) without conversion. + void const* ptrRoutingBias = nullptr; + batchedGemm::trtllm::gen::Dtype dtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16; + float routeScale = 1.0f; + float sumEpsilon = 0.0f; + + void set(routingCustom::Data const& data) + { + ptrRoutingBias = data.mPtrRoutingBias; + dtypeBias = data.mDtypeBias; + routeScale = data.mRouteScale; + sumEpsilon = data.mSumEpsilon; + } + }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], int32_t const (&warpTopKExpertIdx)[K], int32_t laneIdx, int32_t topK, + ParamsT const& params) + { + // Recover sigmoid score: selection_score = sigmoid(raw) + bias, so sigmoid = score - bias + float biasVal + = laneIdx < topK ? loadScalar(params.ptrRoutingBias, warpTopKExpertIdx[laneIdx], params.dtypeBias) : 0.f; + float sigmoidScore = laneIdx < topK ? (static_cast(warpTopKScore[laneIdx]) - biasVal) : 0.f; + float sum = cg::reduce(warp, sigmoidScore, cg::plus()); + if (laneIdx < topK) + { + warpTopKScore[laneIdx] + = static_cast(sigmoidScore * params.routeScale / (sum + params.sumEpsilon)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// ExpertSelectPolicy: encapsulates the entire expert selection logic. +// +// Each policy must provide: +// - template using BaseType +// The data type used for intermediate score computation. +// - template struct Params { void set(Data const&); } +// Policy-specific runtime data, populated from the host-side Data struct. +// Empty for policies that don't need extra data (zero register cost). +// - template +// static void apply(warp, warpTopKScore[K], warpTopKExpertIdx[K], laneIdx, numExperts, topK, +// ptrScores, params) +// Selects the top-K experts and computes their weights. +// +// The default TopKExpertSelect wraps existing PreprocessPolicy + PostprocessPolicy, +// but users can write completely custom policies that bypass the preprocess+topK+postprocess +// pattern (e.g., lookup-table-based expert selection). +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default ExpertSelectPolicy: preprocess + topK reduction + postprocess. +/// Wraps existing PreprocessPolicy and PostprocessPolicy as internal composition. +template +struct TopKExpertSelect +{ + /// BaseType: delegated to the preprocess policy. + template + using BaseType = typename PreprocessPolicy_::template BaseType; + + /// Params: combines preprocess and postprocess runtime parameters. + template + struct Params + { + typename PreprocessPolicy_::template Params mPreprocessParams; + typename PostprocessPolicy_::template Params mPostprocessParams; + + void set(routingCustom::Data const& data) + { + mPreprocessParams.set(data); + mPostprocessParams.set(data); + } + }; + + /// Selects top-K experts using preprocess → topK reduction → postprocess. + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const& warp, + DataType (&warpTopKScore)[K], int32_t (&warpTopKExpertIdx)[K], int32_t const laneIdx, int32_t const numExperts, + int32_t topK, InputType const* ptrScores, KP const& params) + { + DataType minScore = DataType{-INFINITY}; + DataType score[VecSize]; + int32_t idx[VecSize]; + + for (int i = 0; i < VecSize; i++) + { + auto expertIdx = i * WarpSize + laneIdx; + auto newScore = expertIdx < numExperts ? static_cast(ptrScores[expertIdx]) : minScore; + score[i] = newScore; + idx[i] = expertIdx; + } + + // Apply preprocess (e.g. softmax over all scores, sigmoid + bias, ...) + PreprocessPolicy_::apply(warp, score, idx, numExperts, params.mExpertSelectParams.mPreprocessParams); + + // Get the top-k scores and their corresponding expert indices + topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); + + // Apply postprocess (e.g. renormalize, softmax over top-K, scaled renormalize, ...) + PostprocessPolicy_::apply( + warp, warpTopKScore, warpTopKExpertIdx, laneIdx, topK, params.mExpertSelectParams.mPostprocessParams); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace routingCustom +{ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Expert-count tiers (must be multiples of WarpSize=32 and of 4). +// Each tier covers all values ≤ the tier constant. +static constexpr int NumExperts128Experts = 128; +static constexpr int NumExperts160Experts = 160; +static constexpr int NumExperts256Experts = 256; +static constexpr int NumExperts384Experts = 384; +static constexpr int NumExperts512Experts = 512; +static constexpr int NumExperts576Experts = 576; +static constexpr int NumExperts1024Experts = 1024; +static constexpr int MaxSupportedExperts = 2048; + +// TopK tiers (must be ≤ WarpSize=32). +static constexpr int NumTop4Experts = 4; +static constexpr int NumTop8Experts = 8; +static constexpr int NumTop16Experts = 16; +static constexpr int NumTop22Experts = 22; +static constexpr int MaxSupportedTopExperts = 32; + +static constexpr int NumThreads = 1024; +static constexpr int NumWarps = NumThreads / WarpSize; + +static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; +static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; + +static constexpr int BlockKernelMaxNumTokens = 4; +static constexpr int DynBlockKernelMaxNumTokens = 16; +static constexpr int DynBlockKernelMaxNumExperts = 512; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int32_t constexpr getMaxNumExperts(int32_t numExperts) +{ + if (numExperts <= NumExperts128Experts) + { + return NumExperts128Experts; + } + else if (numExperts <= NumExperts160Experts) + { + return NumExperts160Experts; + } + else if (numExperts <= NumExperts256Experts) + { + return NumExperts256Experts; + } + else if (numExperts <= NumExperts384Experts) + { + return NumExperts384Experts; + } + else if (numExperts <= NumExperts512Experts) + { + return NumExperts512Experts; + } + else if (numExperts <= NumExperts576Experts) + { + return NumExperts576Experts; + } + else if (numExperts <= NumExperts1024Experts) + { + return NumExperts1024Experts; + } + else if (numExperts <= MaxSupportedExperts) + { + return MaxSupportedExperts; + } + else + { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// TIER PAIR TYPES — compile-time (MaxNumExperts, MaxNumTopExperts) configuration. +// +// Each Tier declares a supported kernel instantiation. +// TierList, ...> is an ordered list tried from first to last. +// The dispatch picks the FIRST pair where numExperts ≤ E AND topK ≤ K. +// +// Pairs must be sorted so that tighter tiers come first: +// - Sort by E ascending, then by K ascending within equal E. +// - A config (numExperts, topK) always matches the tightest available pair. +// - If the tightest expert tier doesn't have a topK that covers the runtime topK, +// the dispatch falls through to the next larger expert tier that does. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Tier +{ + static constexpr int kExperts = E_; + static constexpr int kTopK = K_; +}; + +template +struct TierList +{ +}; + +// Recursive dispatch: try each tier in order, call `fn` with the first match. +// fn receives (integral_constant, integral_constant) as compile-time args. +// Base case: empty list — no match. +template +inline bool dispatchTierPairs(TierList<>*, Data const& /*data*/, Fn&& /*fn*/) +{ + return false; +} + +// Recursive case: check First, then recurse on Rest... +template +inline bool dispatchTierPairs(TierList*, Data const& data, Fn&& fn) +{ + if (data.mNumExperts <= First::kExperts && data.mTopK <= First::kTopK) + { + fn(std::integral_constant{}, std::integral_constant{}); + return true; + } + return dispatchTierPairs(static_cast*>(nullptr), data, std::forward(fn)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// POLICY TIER CONFIGURATION +// +// ┌─────────────────────────────────────────────────────────────────────────┐ +// │ HOW TO ADD A NEW ROUTING POLICY │ +// │ │ +// │ 1. Define PreProc/PostProc structs (above in this file) │ +// │ 2. Add PolicyTraits with tier list (below) │ +// │ 3. Add enum value to RoutingPreprocessType/RoutingPostprocessType │ +// │ in RoutingKernel.h (if new enum needed) │ +// │ 4. Add an else-if branch to dispatchRoutingPolicy() (bottom of file) │ +// │ — LAUNCH_ROUTING_CUSTOM and queryDispatchedMaxExperts │ +// │ automatically pick it up │ +// │ 5. Set the policy in runner.cu for the routing method │ +// └─────────────────────────────────────────────────────────────────────────┘ +// +// PolicyTraits::Pairs declares the supported (expert, topK) pairs. +// Only these pairs are compiled as kernel instantiations. +// To add support for a new model config, add a Tier to the appropriate TierList. +// +// THREAD-COUNT SAFETY: LAUNCH_ROUTING_FOR_POLICY automatically clamps the launch thread +// count to at least min(MaxNumExperts, 1024) from the dispatched tier. This prevents +// mismatches when a policy's smallest tier is larger than getMaxNumExperts() returns for +// the same numExperts (e.g., 72 experts → getMaxNumExperts returns 128, but a policy +// whose smallest tier is 256 would produce MaxNumExperts=256). See the comment on +// LAUNCH_ROUTING_FOR_POLICY for details. +// +// ┌──────────────────────────────────────────────────────────────────────────────┐ +// │ Policy (PreProc + PostProc) Supported pairs │ +// ├──────────────────────────────────────────────────────────────────────────────┤ +// │ Softmax + None (Default) (128,8) │ +// │ None + Softmax (Renormalize) (128,4) (128,8) (160,8) (256,8) │ +// │ (256,16) (512,8) (512,16) │ +// │ (512,22) (576,8) (2048,32) │ +// │ Sigmoid + SumNorm (SigmoidRenorm) (128,8) │ +// │ SigmoidBias + ScaleS (DS nGroup≤1) (128,8) (256,8) (384,8) (512,8) │ +// │ (512,22) │ +// │ Softmax + SumNorm (RenormNaive) (128,4) (128,8) (256,8) (512,8) │ +// │ (2048,8) │ +// │ None + None (fallback) (128,8) │ +// └──────────────────────────────────────────────────────────────────────────────┘ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default: fallback for new/unknown policies. +/// Provides K8 (tight for most models) + K32 (catch-all for high topK) at each common expert tier. +/// Omits 160/384/576 — those are model-specific and handled by explicit specializations. +/// If a new policy needs a tighter tier, add a PolicyTraits specialization. +template +struct PolicyTraits +{ + using Pairs = TierList, Tier<128, 32>, Tier<256, 8>, Tier<256, 32>, Tier<512, 8>, Tier<512, 32>, + Tier<2048, 8>, Tier<2048, 32>>; +}; + +/// Softmax + None (Default): single config. +template <> +struct PolicyTraits +{ + using Pairs = TierList>; +}; + +/// None + Softmax (Renormalize): many model configs. +template <> +struct PolicyTraits +{ + using Pairs + = TierList, // Mixtral 8x7B (topK=2), Qwen2-MoE (topK=4), Arctic (topK=2), DBRX (topK=4), GPT-OSS + Tier<128, 8>, // DeepSeek-V2-Lite (topK=6), Mixtral 8x22B (topK=2) + Tier<160, 8>, // Qwen3-Coder-480B + Tier<256, 8>, // Mistral Large 3 (topK=8) + Tier<256, 16>, // Models with 256 experts and topK 9..16 + Tier<512, 8>, // Various 512-expert models + Tier<512, 16>, // Various 512-expert models with high topK + Tier<512, 22>, // Nemotron Super V3 (512 experts, topK=22) + Tier<576, 8>, // Customized model with 576 experts + Tier<2048, 32> // Large-expert fallback + >; +}; + +/// Sigmoid + SumNormalize (SigmoidRenorm, Cohere): single config. +template <> +struct PolicyTraits +{ + using Pairs = TierList>; +}; + +/// SigmoidBias + ScaledSumNormalize (DeepSeek nGroup≤1 / MiniMax2 / Kimi-K2 / Nemotron SuperV3). +template <> +struct PolicyTraits +{ + using Pairs = TierList, // Small expert counts (≤128 experts, e.g. DeepSeek-V2-Lite) + Tier<256, 8>, // MiniMax M2 (256 experts, topK=6) + Tier<384, 8>, // Kimi K2 (384 experts) + Tier<512, 8>, // DeepSeek nGroup≤1 (256 experts → E512 fallback) + Tier<512, 22>, // Nemotron Super V3 (512 experts, topK=22, nGroup≤1) + Tier<1024, 32> // Default fallback (expert count may grow beyond 512) + >; +}; + +/// Softmax + SumNormalize (RenormalizeNaive): no specialization needed. +/// At runtime, RenormalizeNaive is always converted to the Renormalize path +/// (None + Softmax) by the runner, so this policy is never dispatched. +/// If it ever is, the default PolicyTraits provides broad fallback coverage. + +/// None + None (fallback for unknown preprocess/postprocess in LAUNCH_ROUTING_CUSTOM). +template <> +struct PolicyTraits +{ + using Pairs = TierList>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// EXAMPLE: Custom ExpertSelectPolicy that bypasses the PreProc→topK→PostProc pattern. +// +// To enable it: +// 1. Uncomment the struct and PolicyTraits below. +// 2. Add an enum value (e.g., RoutingPreprocessType::FirstK) in RoutingKernel.h. +// 3. Add a branch in LAUNCH_ROUTING_CUSTOM that calls LAUNCH_ROUTING_FOR_EXPERT_SELECT. +// 4. Set the enum in runner.cu for the desired routing method type. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* +struct FirstKExpertSelect +{ + template using BaseType = float; + template struct Params { void set(routingCustom::Data const&) {} }; + + template + __forceinline__ __device__ static void apply(cg::thread_block_tile const&, + DataType (&warpTopKScore)[K], int32_t (&warpTopKExpertIdx)[K], int32_t const laneIdx, + int32_t const, int32_t topK, InputType const*, KP const&) + { + if (laneIdx < topK) + { + warpTopKExpertIdx[laneIdx] = laneIdx; + warpTopKScore[laneIdx] = static_cast(1.0f / topK); + } + } +}; + +template <> struct PolicyTraits +{ + using Pairs = TierList>; +}; +*/ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GENERIC DISPATCH MACROS +// +// These macros are fixed infrastructure — they never need editing when adding new +// policies or changing tier support. All configuration lives in PolicyTraits above. +// +// The dispatch iterates PolicyTraits::Pairs (a TierList) via dispatchTierPairs. +// A generic lambda captures the kernel name (macro requirement) and receives +// (expert, topK) as compile-time integral_constants. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Generic per-policy dispatch. Iterates PolicyTraits::Pairs, +// picking the first (expert, topK) pair that covers the runtime values. +// +// IMPORTANT: numThreads is clamped to at least min(MaxNumExperts, 1024) from the dispatched tier. +// Many routing kernels derive their internal NumThreadsBlock from MaxNumExperts and use it for +// grid-stride addressing, initArr strides, and cub::BlockScan. If the caller's numThreads +// (typically getMaxNumExperts(mNumExperts)) is smaller than the tier's MaxNumExperts, the kernel +// would compute wrong indices, skip initialization, and corrupt memory. The max() below +// guarantees the launch thread count always matches or exceeds the kernel's NumThreadsBlock: +// - "derive from tier" kernels: numThreadsHist < MaxNumExperts → bumped to MaxNumExperts ✓ +// - "fixed 1024" kernels (cluster): numThreads=1024 ≥ MaxNumExperts → unchanged ✓ +#define LAUNCH_ROUTING_FOR_POLICY( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, PreProc, PostProc) \ + [&](auto pt_tag_) \ + { \ + using Pairs_ = typename decltype(pt_tag_)::Pairs; \ + bool dispatched_ = dispatchTierPairs(static_cast(nullptr), data, \ + [&](auto eTag_, auto kTag_) \ + { \ + constexpr int tierMaxExp_ = decltype(eTag_)::value; \ + constexpr int tierThreads_ = tierMaxExp_ <= 1024 ? tierMaxExp_ : 1024; \ + int const effectiveThreads_ = std::max(static_cast(numThreads), tierThreads_); \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, effectiveThreads_, smemSize, stream, \ + PreProc, PostProc, decltype(eTag_)::value, decltype(kTag_)::value); \ + }); \ + if (!dispatched_) \ + { \ + TLLM_LOG_ERROR("No tier covers numExperts=%d topK=%d", data.mNumExperts, data.mTopK); \ + } \ + }(PolicyTraits{}) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// CUSTOM EXPERT SELECT DISPATCH +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Generic dispatch for custom ExpertSelectPolicy. PolicyTraits key is . +// Same numThreads clamping as LAUNCH_ROUTING_FOR_POLICY — see comment above. +#define LAUNCH_ROUTING_FOR_EXPERT_SELECT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, ExpertSelect) \ + [&](auto pt_tag_) \ + { \ + using Pairs_ = typename decltype(pt_tag_)::Pairs; \ + bool dispatched_ = dispatchTierPairs(static_cast(nullptr), data, \ + [&](auto eTag_, auto kTag_) \ + { \ + constexpr int tierMaxExp_ = decltype(eTag_)::value; \ + constexpr int tierThreads_ = tierMaxExp_ <= 1024 ? tierMaxExp_ : 1024; \ + int const effectiveThreads_ = std::max(static_cast(numThreads), tierThreads_); \ + LAUNCH_ROUTING_WITH_EXPERT_SELECT(data, coopLaunch, kernel, numBlocks, effectiveThreads_, smemSize, \ + stream, ExpertSelect, decltype(eTag_)::value, decltype(kTag_)::value); \ + }); \ + if (!dispatched_) \ + { \ + TLLM_LOG_ERROR("No tier covers numExperts=%d topK=%d", data.mNumExperts, data.mTopK); \ + } \ + }(PolicyTraits{}) + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// PUBLIC DISPATCH MACROS +// +// These are the only macros that call sites use. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Lightweight dispatch for utility kernels (histogram, init-counts, offsets) that do NOT use +// expert select policies, InputT, or MaxNumTopExperts. +// - Always uses NoOp expert select (no policy dispatch). +// - Always uses a fixed NumTop8Experts (no topK-tier dispatch). +// - Dispatches only on expert tiers. +// This is intentionally NOT routed through LAUNCH_ROUTING_FOR_POLICY to avoid +// instantiating all topK tiers — utility kernels don't use topK at all. +#define LAUNCH_ROUTING_CUSTOM_NO_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mNumExperts <= NumExperts128Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts128Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts160Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts160Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts256Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts256Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts384Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts384Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts512Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts512Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts576Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts576Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= NumExperts1024Experts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, NumExperts1024Experts, NumTop8Experts); \ + } \ + else if (data.mNumExperts <= MaxSupportedExperts) \ + { \ + LAUNCH_ROUTING_WITH_POLICIES(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + NoOpPreprocess, NoOpPostprocess, MaxSupportedExperts, NumTop8Experts); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +// Single source of truth for runtime → compile-time policy dispatch. +// Maps (mPreprocessType, mPostprocessType) to compile-time (PreProc, PostProc) policy types. +// The callback `fn` receives instances of the policy types (e.g., SigmoidBiasPreprocess{}). +// Both LAUNCH_ROUTING_CUSTOM and queryDispatchedMaxExperts use this function, +// so they are always in sync. See "HOW TO ADD A NEW ROUTING POLICY" above. +template +inline void dispatchRoutingPolicy(Data const& data, Fn&& fn) +{ + if (data.mPreprocessType == RoutingPreprocessType::SigmoidBias) + fn(SigmoidBiasPreprocess{}, ScaledSumNormalizePostprocess{}); + else if (data.mPreprocessType == RoutingPreprocessType::Sigmoid) + fn(SigmoidPreprocess{}, SumNormalizePostprocess{}); + else if (data.mPreprocessType == RoutingPreprocessType::Softmax + && data.mPostprocessType == RoutingPostprocessType::None) + fn(SoftmaxPreprocess{}, NoOpPostprocess{}); + else if (data.mPreprocessType == RoutingPreprocessType::Softmax) + fn(SoftmaxPreprocess{}, SumNormalizePostprocess{}); + else if (data.mPostprocessType == RoutingPostprocessType::Softmax) + fn(NoOpPreprocess{}, SoftmaxPostprocess{}); + else + fn(NoOpPreprocess{}, NoOpPostprocess{}); +} + +// Query the MaxNumExperts that the policy tier dispatch would select for the given data. +inline int32_t queryDispatchedMaxExperts(Data const& data) +{ + int32_t result = getMaxNumExperts(data.mNumExperts); + dispatchRoutingPolicy(data, + [&](auto preProc, auto postProc) + { + using Pairs = typename PolicyTraits::Pairs; + dispatchTierPairs( + static_cast(nullptr), data, [&](auto eTag, auto /*kTag*/) { result = decltype(eTag)::value; }); + }); + return result; +} + +// Top-level dispatch: maps runtime preprocess/postprocess enums to compile-time policy types, +// then delegates to LAUNCH_ROUTING_FOR_POLICY which reads PolicyTraits for tier support. +#define LAUNCH_ROUTING_CUSTOM(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + dispatchRoutingPolicy(data, \ + [&](auto preProc_, auto postProc_) \ + { \ + LAUNCH_ROUTING_FOR_POLICY(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + decltype(preProc_), decltype(postProc_)); \ + }) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingCustom +} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDeepSeek.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDeepSeek.cu new file mode 100644 index 00000000000..f0da3c06e75 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDeepSeek.cu @@ -0,0 +1,628 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// DeepSeek routing: entry point, constants, dispatch macros, kernel definitions, and launch wrappers. +// +// Kernel inventory: +// 1. routingMainKernel — DeepSeek-specific main kernel (sigmoid + bias + group TopK) +// 2. routingIndicesClusterKernel — single-cluster fused kernel (SM90+) +// 3. launchCoopKernel — delegates to routingCustom's coop implementation +// 4. launchInitExpertCounts — zero expert counts +// 5. launchHistogramKernel — histogram from packed TopK +// 6. launchOffsetsKernel — prefix-scan + permutation + +#include "RoutingCustomPolicy.cuh" +#include "RoutingKernel.cuh" + +namespace moe::dev::routing +{ + +// Forward declaration of routingCustom's coop kernel (used by DeepSeek's coop path) +namespace routingCustom +{ +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream); +} // namespace routingCustom + +namespace routingDeepSeek +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Constants and dispatch macros +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int NumNemotronExperts = 512; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int NumExperts1024Experts = 1024; +static constexpr int MaxSupportedExpertCount = std::max({NumExperts1024Experts, NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); +static constexpr int NumTopGroupScores = 2; +static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxNumGroups = 8; + +static constexpr int NumTop8Experts = 8; +static constexpr int NumTop22Experts = 22; +static constexpr int MaxSupportedTopExperts = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int constexpr getMaxNumExperts(int32_t numExperts) +{ + if (numExperts <= topk::MaxNumExpertsUnit) + { + return topk::MaxNumExpertsUnit; + } + else if (numExperts <= NumDeepseekExperts) + { + return NumDeepseekExperts; + } + else if (numExperts <= NumKimiK2Experts) + { + return NumKimiK2Experts; + } + else if (numExperts <= NumNemotronExperts) + { + return NumNemotronExperts; + } + else if (numExperts <= NumExperts1024Experts) + { + return NumExperts1024Experts; + } + else + { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper macro: dispatch on topK tier for a given numExperts tier. +#define LAUNCH_DEEPSEEK_WITH_TOPK( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, forceFloatInput, numExperts) \ + if (data.mTopK <= NumTop8Experts) \ + { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, forceFloatInput, numExperts, NumTop8Experts); \ + } \ + else if (data.mTopK <= NumTop22Experts) \ + { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, forceFloatInput, numExperts, NumTop22Experts); \ + } \ + else \ + { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, forceFloatInput, numExperts, MaxSupportedTopExperts); \ + } + +#define LAUNCH_ROUTING_DEEPSEEK( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, forceFloatInput) \ + if (data.mNumExperts <= topk::MaxNumExpertsUnit) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, topk::MaxNumExpertsUnit); \ + } \ + else if (data.mNumExperts <= NumDeepseekExperts) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumDeepseekExperts); \ + } \ + else if (data.mNumExperts <= NumKimiK2Experts) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumKimiK2Experts); \ + } \ + else if (data.mNumExperts <= NumNemotronExperts) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumNemotronExperts); \ + } \ + else if (data.mNumExperts <= NumExperts1024Experts) \ + { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, NumExperts1024Experts); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 1. Main kernel — DeepSeek-specific routing with sigmoid activation, bias, and group TopK. +// Handles both grouped and non-grouped expert selection. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void routingMainKernel(KernelParams params) +{ + // declare types + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + + // declare shared memory structure + // number of experts is bounded by number of threads + __shared__ float __attribute((aligned(128))) smemScoreSigmoid[KernelParams::MaxNumExperts]; + __shared__ float __attribute((aligned(128))) smemScoreBias[KernelParams::MaxNumExperts]; + // number of expert groups is bounded by number of warps + __shared__ float __attribute((aligned(128))) smemGroupScores[MaxNumGroups]; + + // needed for warp reduce + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + // for the final reduction of weight norm, only some lanes need to participate + int32_t laneIdx = threadIdx.x % WarpSize; + int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + // note that for invalid scores, we simply use a negative value: + // they work well even with the compacted format used in topK, and + // sigmoid / bias activated scores cannot be negative + static constexpr float invalidScoreFloat = float{-INFINITY}; + const OutputT invalidScore = OutputT{invalidScoreFloat}; + + // load bias already; each warp represents one expert group + auto threadExpert = threadIdx.x; + bool expertSelected = threadExpert < params.mNumExperts; + if constexpr (KernelParams::UseGroups) + { + threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx; + // Inactive warps (warpIdx >= mNumExpertGroups) must NOT return early because they + // still need to reach the __syncthreads() barriers below. Setting expertSelected + // to false is enough to keep them from doing any out-of-bounds reads or smem writes. + expertSelected = (warpIdx < params.mNumExpertGroups) && (laneIdx < params.mNumExpertsPerGroup); + } + auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert; + auto biasVal = expertSelected + ? static_cast(loadScalar(params.mPtrRoutingBias, threadExpert, params.mDtypeBias)) + : invalidScore; + + // initialize the mPtrExpertCounts + if (params.mPtrExpertCounts) + { + int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; + int32_t globalThreadStride = gridDim.x * blockDim.x; + int32_t expertCountsNum = 2 * params.mNumExperts; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + } + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // trigger the secondary kernel when using PDL, then wait on primary + if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrScores != nullptr) + { + // get our assigned thread score; each warp represents one expert group + float score = expertSelected ? static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; + // get the sigmoid score + // note that for invalid values, we simply use a negative value: + // sigmoig scores are always strictly positive + auto scoreSigmoid = sigmoid_accurate(score); + // write the sigmoid score to shared for later use + if (expertSelected) + { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + // get the score with bias + // note that with invalid values, because sigmoid is < 1 and bias is -1, + // we must get a negative value, which is smaller than any valid value + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + + if (expertSelected) + { + smemScoreBias[threadExpert] = scoreBias; + } + + // registers for top group score reduction + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; + float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups + int32_t topGroupIdx[MaxNumTopGroups]; + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; + float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK + int32_t topExperts[KernelParams::MaxNumTopExperts]; + + if constexpr (KernelParams::UseGroups) + { + topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, + /* minValue */ invalidScoreFloat); + // get the final group score and write it to shared + if (cute::elect_one_sync()) + { + auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; + smemGroupScores[warpIdx] = groupScore; + } + } + + // make group scores available to all warps + __syncthreads(); + + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + if constexpr (KernelParams::UseGroups) + { // a single warp performs the selection of top groups, and goes on to select the final experts + if (warpIdx == 0) + { + float groupScore = laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat; + topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + /* minValue */ invalidScoreFloat); + // final expert selection: get relevant indexes and scores from shared +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) + { // bound of params.mNumLimitedGroups + auto groupIdx = topGroupIdx[ii]; + expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; + // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. + // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, + // thus groupIdx <= params.mNumExpertGroups - 1 => + // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup + // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, + // so the access is safe here + expertScoreGroup[ii] + = (ii < params.mNumLimitedGroups) && (groupIdx < params.mNumExpertGroups) && expertSelected + ? smemScoreBias[expertIdxGroup[ii]] + : invalidScoreFloat; + } + + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } + else if constexpr (KernelParams::MaxNumExperts > topk::MaxNumExpertsUnit) + { + // without groups, each thread just takes `MaxNumTopGroups` experts + int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; + int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; + __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; + __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; + if (warpIdx < NumExpertWarps) + { + int offset = warpIdx * WarpSize * MaxNumTopGroups; +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) + { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = offset + expertIdx < params.mNumExperts ? smemScoreBias[offset + expertIdx] + : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + + if (laneIdx < params.mTopK) + { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + } + else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) + { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = invalidScoreFloat; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] + = MaxSupportedExpertCount - 1; + } + } + __syncthreads(); + if (warpIdx == 0) + { + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; + float intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; + for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) + { + int ii = i / WarpSize; + if (i < NumInterTopK) + { + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; + } + else + { + intermediateScore[ii] = invalidScoreFloat; + intermediateExpert[ii] = KernelParams::MaxNumExperts - 1; + } + } + topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } + else + { + if (warpIdx == 0) + { + // without groups, each thread just takes `MaxNumTopGroups` experts +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) + { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] + = expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } + + if (warpIdx == 0) + { + // determine our lane's expert index and write to output + int32_t expertIdx = 0; +#pragma unroll + for (int ii = 0; ii < params.mTopK; ++ii) + { // bound of params.mTopK + expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx; + } + // determine whether our expert is local to this GPU + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent + && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; + + // write expert idx out already + auto idxTopK = blockIdx.x * params.mTopK + laneIdx; + if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) + { + PackedScoreIdx packedScore{static_cast(finalScore), static_cast(expertIdx)}; + params.mPtrTopKPacked[idxTopK] = packedScore; + } + + if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) + { + params.mPtrTopKWeights[idxTopK] = finalScore; + } + } + } +} + +static void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream) +{ + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 2. Cluster kernel — single-cluster fused kernel (SM90+). +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesClusterKernel(KernelParams params) +{ + using OutputT = typename KernelParams::OutputT; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const clusterBlockRank = blockIdx.x; + + //@todo: try to move it into routingPermutation + if (params.mUsePdl) + { + cudaGridDependencySynchronize(); + } + routingPermutation(params, nullptr, warpIdx, clusterBlockRank); +} +#else +__global__ void routingIndicesClusterKernel(KernelParams params) +{ + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif + +static void launchClusterKernel(Data& data, int numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// 3-6. Launch wrappers for shared kernels. +// Coop delegates to routingCustom; others use LAUNCH_ROUTING_DEEPSEEK macro. +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static void launchCoopKernel(Data& data, int numBlocksCoop, int /*numThreadsHist*/, void* stream) +{ + // Use routingCustom's coop kernel implementation (they are identical). + // Convert DeepSeek Data to Custom Data for launching. + routingCustom::Data customData; + // Copy base fields + static_cast(customData) = static_cast(data); + // Set routingCustom-specific defaults (not needed for coop kernel) + customData.mDtypeOutput = data.mDtypeOutput; + // The coop kernel doesn't read routing logits (mPtrInput), only mPtrTopKPacked. + // Set mDtypeInput = mDtypeOutput so the dispatched template is , + // avoiding an unnecessary mixed-type instantiation. + customData.mDtypeInput = data.mDtypeOutput; + customData.mPreprocessType = RoutingPreprocessType::None; + customData.mPostprocessType = RoutingPostprocessType::Softmax; + + // Recompute numThreadsHist using routingCustom's expert tiers (128, 512, 2048), + // since the custom coop kernel dispatch selects template parameters based on these tiers. + // DeepSeek's getMaxNumExperts uses different tiers (256, 384, 512) which would mismatch. + uint32_t const customNumThreadsHist + = std::min(1024u, static_cast(routingCustom::getMaxNumExperts(data.mNumExperts))); + routingCustom::launchCoopKernel(customData, numBlocksCoop, customNumThreadsHist, stream); +} + +static void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +static void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, void* stream) +{ + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Entry point +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run(Data& data, void* stream) +{ + TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to the shared post-topK pipeline which handles all path selection + // (single-block, single-cluster, coop, multi-kernel) automatically. + // No routing-method-specific logic needed. + if (data.mPtrTopKIds != nullptr || (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) + { + if (data.mPtrTopKIds != nullptr) + { + TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for DeepSeek routing."); + } + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + runPostTopKPipeline(data, numThreadsHist, stream); + return; + } + + // After this point, input is mPtrScores (raw logits that need DeepSeek-specific routing). + TLLM_CHECK_WITH_INFO(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); + TLLM_CHECK_WITH_INFO(data.mNumExperts >= data.mTopK, "Routing kernel expects topK (%d) to be <= numExperts (%d)", + data.mTopK, data.mNumExperts); + TLLM_CHECK_WITH_INFO(data.mNumExperts <= MaxSupportedExpertCount, + "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, MaxSupportedExpertCount); + TLLM_CHECK_WITH_INFO(data.mTopK <= MaxSupportedTopExperts, "Routing kernel expects topK experts <= %d, got %d", + MaxSupportedTopExperts, data.mTopK); + + if (data.mPtrExpandedIdxToPermutedIdx != nullptr || data.mPtrPermutedIdxToExpandedIdx != nullptr + || data.mPtrPermutedIdxToTokenIdx != nullptr) + TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr && data.mPtrPermutedIdxSize, + "If permuted index is required, `mPtrTopKPacked` is also required"); + + // Routing needs to be executed - validate routing kernel constraints + if (data.mNumExpertGroups > 1) + { + TLLM_CHECK_WITH_INFO(data.mNumExpertGroups <= MaxNumGroups, + "Routing kernel expects #expert groups %d to be <= max groups %d", data.mNumExpertGroups, MaxNumGroups); + TLLM_CHECK_WITH_INFO(data.mNumExperts % data.mNumExpertGroups == 0, + "Routing kernel expects #experts %d to be a multiple of #expert groups %d", data.mNumExperts, + data.mNumExpertGroups); + TLLM_CHECK_WITH_INFO(data.mNumExperts / data.mNumExpertGroups <= WarpSize, + "Routing kernel expects #experts per group <= warp size (%d), got %d experts / %d groups = %d experts " + "per group", + WarpSize, data.mNumExperts, data.mNumExpertGroups, data.mNumExperts / data.mNumExpertGroups); + TLLM_CHECK_WITH_INFO(data.mNumLimitedGroups <= MaxNumTopGroups, + "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); + + TLLM_CHECK_WITH_INFO(data.mNumExpertGroups >= data.mNumLimitedGroups, + "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, + data.mNumExpertGroups); + TLLM_CHECK_WITH_INFO( + data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + } + + int const numBlocks = data.mNumTokens; + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + static int const smMajor = tensorrt_llm::common::getSMVersion() / 10; + // Step 1: Run DeepSeek-specific topK computation (writes to mPtrTopKPacked) + int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); + launchMainKernel(data, numBlocks, numThreadsMain, stream); + + // Step 2: Permutation pipeline (reads from mPtrTopKPacked written by step 1) + if (data.mPtrPermutedIdxSize != nullptr) + { + TLLM_CHECK_WITH_INFO(data.mPtrCtaIdxXyToBatchIdx != nullptr && data.mPtrCtaIdxXyToMnLimit != nullptr + && data.mPtrNumNonExitingCtas != nullptr, + "DeepSeek routing step 2 requires grouped-GEMM launch config buffers " + "(mPtrCtaIdxXyToBatchIdx, mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas)"); + + bool const useSingleCluster = (smMajor >= 9) && (data.mNumTokens <= 1024); + if (!useSingleCluster) + { + TLLM_CHECK_WITH_INFO( + data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); + } + else + { + data.mPtrExpertCounts = nullptr; // Set it to nullptr for single-cluster code path, as it won't be used + } + + // Number of blocks we can use in the cooperative kernel + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + // WAR: Reserve 8 SMs for overlapping kernels. + int const numBlocksCoop = smCount - 8; + // Maximum number of tokens supported by the kernel using a cooperative launch. + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + + // Last routing kernel: disable overlap so GEMM waits via stream serialization. + bool const savedAllowOverlap = data.mPdlAllowOverlap; + + if (useSingleCluster) + { + data.mPdlAllowOverlap = false; + launchClusterKernel(data, numThreadsHist, stream); + } + else if ((smMajor >= 9) && (data.mNumTokens <= maxTokensCoop)) + { + data.mPdlAllowOverlap = false; + launchCoopKernel(data, numBlocksCoop, numThreadsHist, stream); + } + else + { + const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; + const int32_t histogramEltsPerBlock = 8 * numThreadsHist; + const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + const int32_t maxNumBlocks = 1024; + + int const numBlocksHistogram + = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets + = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); + data.mPdlAllowOverlap = false; + launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); + } + + data.mPdlAllowOverlap = savedAllowOverlap; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#undef LAUNCH_DEEPSEEK_WITH_TOPK +#undef LAUNCH_ROUTING_DEEPSEEK + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDevKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDevKernel.h new file mode 100644 index 00000000000..e3cd64f0494 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingDevKernel.h @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../DevKernel.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Routing-specific launch macros. +// These macros build on top of LAUNCH_ESC from DevKernel.h. +// +// Unlike the generic LAUNCH_PDL (which instantiates 2 kernels for UsePdl=true/false), +// LAUNCH_PDL_ROUTING instantiates only 1 kernel and passes UsePdl as a runtime field +// in KernelParams. This halves routing kernel instantiations. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define LAUNCH_PDL_ROUTING(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ + do \ + { \ + cudaLaunchConfig_t config{}; \ + config.gridDim = numBlocks; \ + config.blockDim = numThreads; \ + config.dynamicSmemBytes = smemSize; \ + config.stream = (cudaStream_t) stream; \ + \ + cudaLaunchAttribute attributes[2] = {}; \ + attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ + attributes[0].val.programmaticStreamSerializationAllowed = int(data.mUsePdl && data.mPdlAllowOverlap); \ + attributes[1].id = cudaLaunchAttributeCooperative; \ + attributes[1].val.cooperative = int(coopLaunch); \ + config.attrs = attributes; \ + config.numAttrs = 2; \ + auto params = KernelParams::setKernelParams(data); \ + auto kernelTyped = kernel>; \ + if (smemSize > 48 * 1024) \ + TLLM_CUDA_CHECK(cudaFuncSetAttribute(kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \ + TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernelTyped, params)); \ + } while (0) + +// Llama4 dispatch: uses data.mDtypeOutput. +#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ + } + +// DeepSeek dispatch: uses data.mDtypeOutput. +#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, forceFloatInput, numExperts, numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32 && extraFlag) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && extraFlag) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), kernel, numBlocks, numThreads, \ + smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && forceFloatInput) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), kernel, numBlocks, numThreads, \ + smemSize, stream); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported dtypeExpW"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// routingCustom dispatch: uses data.mDtypeOutput (OutputT) and data.mDtypeInput (InputT). +// These are routingCustom::Data fields, NOT used by DeepSeek/Llama4 macros. +// Wraps (PreProc, PostProc) into TopKExpertSelect for the standard preprocess→topK→postprocess flow. +#define LAUNCH_ROUTING_WITH_POLICIES( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, PreProc, PostProc, numExperts, numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, float, numExperts, numTopExperts, TopKExpertSelect), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && data.mDtypeInput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, TopKExpertSelect), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, TopKExpertSelect), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported dtypeOutput"); \ + } + +// routingCustom dispatch for custom ExpertSelectPolicy types that don't use PreProc/PostProc. +// Use this when the policy does NOT follow the standard preprocess→topK→postprocess pattern. +// ExpertSelect must satisfy the ExpertSelectPolicy concept (see RoutingCustomPolicy.cuh). +#define LAUNCH_ROUTING_WITH_EXPERT_SELECT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, ExpertSelect, numExperts, numTopExperts) \ + if (data.mDtypeOutput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, ExpertSelect), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16 && data.mDtypeInput == tg::Dtype::Fp32) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, ExpertSelect), kernel, numBlocks, numThreads, \ + smemSize, stream); \ + } \ + else if (data.mDtypeOutput == tg::Dtype::Bfloat16) \ + { \ + LAUNCH_PDL_ROUTING(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, ExpertSelect), kernel, numBlocks, \ + numThreads, smemSize, stream); \ + } \ + else \ + { \ + TLLM_LOG_ERROR("Unsupported dtypeOutput"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingFromTopKIds.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingFromTopKIds.cu new file mode 100644 index 00000000000..ca2308cf730 --- /dev/null +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingFromTopKIds.cu @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingCustomPolicy.cuh" +#include "RoutingKernel.cuh" +#include "RoutingKernel.h" +#include + +namespace moe::dev::routing +{ +namespace routingCustom +{ +// Forward declarations of launch functions +void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream); +void launchDynBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream); +void launchClusterKernel(Data const& data, void* stream); +void launchCoopKernel(Data const& data, int numBlocksCoop, uint32_t numThreadsHist, void* stream); +void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream); +void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, void* stream); +void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, void* stream); +} // namespace routingCustom + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Implementation of shared post-topK pipeline for all routing methods. +// When topK is already computed (mPtrTopKIds or mPtrTopKPacked), we don't need +// routing-method-specific logic, so all methods can use the same workflow. +// This function handles all path selection: single-block, single-cluster, coop, multi-kernel. +template +void runPostTopKPipeline(DataType const& data, uint32_t /*numThreadsHist*/, void* stream) +{ + // Convert to routingCustom::Data for launching (kernels are shared) + routingCustom::Data customData; + // Copy base fields + static_cast(customData) = static_cast(data); + // Set routingCustom-specific defaults (not needed for utility kernels) + customData.mDtypeOutput = data.mDtypeOutput; + // The post-TopK kernels don't read routing logits (mPtrInput), only mPtrTopKPacked. + // Set mDtypeInput = mDtypeOutput so the dispatched template is , + // avoiding an unnecessary mixed-type instantiation. + customData.mDtypeInput = data.mDtypeOutput; + customData.mPreprocessType = RoutingPreprocessType::None; + // Softmax is chosen for its broad tier coverage, not because we need softmax. + // The TopKIds/TopKPacked branches never call ExpertSelectPolicy::apply(), + // so the postprocess is never executed. Using Softmax avoids extra template + // instantiations by reusing tiers already compiled for other models. + customData.mPostprocessType = RoutingPostprocessType::Softmax; + + // Recompute numThreadsHist using routingCustom's expert tiers, since we launch custom kernels. + // Different routing methods (DeepSeek, Llama4) may have different expert tier thresholds + // that don't match routingCustom's tiers (128, 512, 2048). + uint32_t const numThreadsHist + = std::min(1024u, static_cast(routingCustom::getMaxNumExperts(data.mNumExperts))); + + // Determine which path to use based on token count + static int const smMajor = tensorrt_llm::common::getSMVersion() / 10; + bool const useStaticBlock = data.mNumTokens <= routingCustom::BlockKernelMaxNumTokens; + bool const useDynBlock = !useStaticBlock && data.mNumTokens <= routingCustom::DynBlockKernelMaxNumTokens + && data.mNumExperts <= routingCustom::DynBlockKernelMaxNumExperts; + bool const useSingleBlock = useStaticBlock || useDynBlock; + bool const useSingleCluster = (smMajor >= 9) && (data.mNumTokens <= routingCustom::MaxNumTokensSingleClusterScores); + + routingCustom::Data lastKernelData = customData; + lastKernelData.mPdlAllowOverlap = false; + + if (useDynBlock) + { + routingCustom::launchDynBlockKernel(lastKernelData, numThreadsHist, stream); + } + else if (useStaticBlock) + { + routingCustom::launchBlockKernel(lastKernelData, numThreadsHist, stream); + } + else if (useSingleCluster) + { + routingCustom::launchClusterKernel(lastKernelData, stream); + } + else + { + bool const canUseCoop = (smMajor >= 9) && (data.mNumExperts <= 1024) && (data.mPtrPermutedIdxSize != nullptr); + bool useCoop = false; + int numBlocksCoop = 0; + + if (canUseCoop) + { + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + numBlocksCoop = smCount - 8; + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + useCoop = (data.mNumTokens <= maxTokensCoop); + } + + TLLM_CHECK_WITH_INFO( + data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); + + if (useCoop) + { + routingCustom::launchInitExpertCounts(customData, numThreadsHist, stream); + routingCustom::launchCoopKernel(lastKernelData, numBlocksCoop, numThreadsHist, stream); + } + else + { + routingCustom::launchInitExpertCounts(customData, numThreadsHist, stream); + + int32_t const expandedIdxSize = data.mNumTokens * data.mTopK; + int32_t const histogramEltsPerBlock = 8 * numThreadsHist; + int32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + int32_t const maxNumBlocks = 1024; + + int const numBlocksHistogram + = std::min((expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets + = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + routingCustom::launchHistogramKernel(customData, numBlocksHistogram, numThreadsHist, stream); + routingCustom::launchOffsetsKernel(lastKernelData, numBlocksOffsets, numThreadsHist, stream); + } + } +} + +// Explicit instantiations for the three routing method Data types +template void runPostTopKPipeline(routingCustom::Data const&, uint32_t, void*); +template void runPostTopKPipeline(routingDeepSeek::Data const&, uint32_t, void*); +template void runPostTopKPipeline(routingLlama4::Data const&, uint32_t, void*); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.cuh similarity index 73% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.cuh index 4bc7b56aa18..f5c57b1611c 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cuh +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include "DevKernel.h" +#include "RoutingDevKernel.h" #include "RoutingKernel.h" #include "RoutingKernelTopK.cuh" @@ -48,6 +48,21 @@ static constexpr int NumEltsPerOffsetTilePerThread = 8; //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dereference a type-erased pointer at the given index, reading the value in its native dtype. +/// Returns float since routing computations are done in float for numerical stability. +__forceinline__ __device__ float loadScalar(void const* ptr, int idx, batchedGemm::trtllm::gen::Dtype dtype) +{ + namespace tg = batchedGemm::trtllm::gen; + switch (dtype) + { + case tg::Dtype::Fp32: return static_cast(ptr)[idx]; + case tg::Dtype::Bfloat16: return static_cast(static_cast<__nv_bfloat16 const*>(ptr)[idx]); + default: return 0.f; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + static __device__ inline float sigmoid_accurate(float x) { return 0.5f * tanhf(0.5f * x) + 0.5f; @@ -391,7 +406,7 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx #pragma unroll for (int e = 0; e < ExpertsPerThread; e++) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { numCta[e] = divUpLog2(count[e], params.mPaddingLog2); } @@ -411,7 +426,6 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx int expert = threadIdx.x * ExpertsPerThread + e; if (expert < params.mNumExperts) { - // Strided loop to share this work between blocks. for (int32_t cta = clusterBlockRank; cta < numCta[e]; cta += NumBlocksPerCluster) { const int32_t localExpertIdx @@ -419,7 +433,7 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx params.mPtrCtaIdxXyToBatchIdx[ctaOffset[e] + cta] = localExpertIdx; int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffset[e] + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffset[e], params.mPaddingLog2) + count[e]; @@ -432,9 +446,8 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx params.mPtrCtaIdxXyToMnLimit[ctaOffset[e] + cta] = min(mnLimit1, mnLimit2); } - // get the padded offset associated with this expert int32_t offset; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { offset = mulLog2(ctaOffset[e], params.mPaddingLog2); } @@ -443,16 +456,14 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx offset = mulTileN(ctaOffset[e], params.mTileTokensDim); } - // write expert offsets to shared smemExpertOffset[expert] = offset + blockExpertOffset[e]; } } - // write out padded count if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -472,17 +483,6 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx // implement break with EXIT. __cluster_barrier_wait(); - // trigger the secondary kernel when using PDL - // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, - // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens - // TODO: this is not sufficient to ensure visibility in the next kernel! -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - // each thread has the same "expanded indexes" assigned to it as above // at this point, we know the final offsets of experts and the offsets within // experts, which allows writing the final index values @@ -515,6 +515,18 @@ __device__ void routingPermutation(KernelParams params, PackedScoreIdx params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; } } + + // Trigger the secondary kernel AFTER all global memory writes are complete. + // The downstream kernels (permute, FC1 GEMM) depend on mPtrCtaIdxXyToBatchIdx, + // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas, mPtrPermutedIdxSize, AND + // mPtrExpandedIdxToPermutedIdx / mPtrPermutedIdxToTokenIdx. + // Triggering before the permutation writes causes the consumer to read stale data → NaN. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -556,11 +568,10 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa __syncthreads(); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid and trigger secondary kernel. - if constexpr (KernelParams::UsePdl) + // Wait on primary grid (but do NOT trigger yet — trigger after atomicAdd to mPtrExpertCounts). + if (params.mUsePdl) { cudaGridDependencySynchronize(); - cudaTriggerProgrammaticLaunchCompletion(); } #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) @@ -640,6 +651,15 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa atomicAdd(¶ms.mPtrExpertCounts[expert], localExpertCount); } } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger AFTER all atomicAdds to mPtrExpertCounts are done, so the next kernel + // (routingIndicesOffsetsKernel) sees the complete histogram. + if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -673,7 +693,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -704,7 +724,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #pragma unroll for (int e = 0; e < ExpertsPerThread; e++) { - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { numCta[e] = divUpLog2(count[e], params.mPaddingLog2); } @@ -723,9 +743,8 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa int expert = threadIdx.x * ExpertsPerThread + e; if (expert < params.mNumExperts) { - // Get the padded offset associated with this expert int32_t offset; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { offset = mulLog2(ctaOffset[e], params.mPaddingLog2); } @@ -734,19 +753,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa offset = mulTileN(ctaOffset[e], params.mTileTokensDim); } - // Write expert offsets to shared smemExpertOffset[expert] = offset; } } - // Sync to make expert offsets available to all threads. __syncthreads(); - // The first block writes out padded count (use last warp of actual thread count) if (blockIdx.x == 0 && warpIdx == NumThreadsBlock / WarpSize - 1 && cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -764,7 +780,6 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa int expert = threadIdx.x * ExpertsPerThread + e; if (expert < params.mNumExperts) { - // Strided loop to share this work between blocks. for (int32_t cta = blockIdx.x; cta < numCta[e]; cta += gridDim.x) { const int32_t localExpertIdx @@ -772,7 +787,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa params.mPtrCtaIdxXyToBatchIdx[ctaOffset[e] + cta] = localExpertIdx; int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffset[e] + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffset[e], params.mPaddingLog2) + count[e]; @@ -965,7 +980,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa // Trigger secondary kernel. // Note: this does not guarantee the visibility of prior writes unless the consumer executes a // dependency sync. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } @@ -988,7 +1003,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -998,11 +1013,307 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelPa #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Cooperative launch kernel: fuses histogram + offsets computation for medium token counts. +// This kernel is shared by routingCustom, routingDeepSeek, and can be used by other routing methods. +// It uses cooperative groups to synchronize across multiple CTAs and compute expert counts, +// offsets, and permutation indices in a single kernel launch. +// +// Requirements: +// - MaxNumExperts <= 1024 (enforced by static_assert) +// - SM90+ architecture (cooperative groups) +// - mPtrPermutedIdxSize must be non-null (needed for permutation) +// +// The kernel handles both mPtrTopKIds and mPtrTopKPacked input formats. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesCoopKernel(KernelParams params) +{ + // number of experts is bounded by number of threads (coop kernel requires MaxNumExperts <= 1024) + using OutputT = typename KernelParams::OutputT; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + static constexpr int NumThreads = MaxNumExperts; + static_assert(MaxNumExperts <= 1024, "Coop kernel requires MaxNumExperts <= 1024"); + + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[MaxNumExperts]; + // needed for the exclusive sum of token offsets + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage tempStorage; + // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. + static constexpr int MaxExpandedIdxPerThread = 64; + + // Initialize grid. + cg::grid_group grid = cg::this_grid(); + // Note: the following is more efficient than grid.block_index() because we don't use y and z. + int32_t const gridBlockIdx = blockIdx.x; + int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; + int32_t const numBlocks = gridDim.x; + int32_t const numThreadsPerGrid = numBlocks * NumThreads; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + + auto expandedIdxSize = params.mNumTokens * params.mTopK; + + // pre-fill the counts with 0 — each thread represents one expert + smemExpertCount[threadIdx.x] = 0; + __syncthreads(); + + // then wait on primary grid + if (params.mUsePdl) + { + cudaGridDependencySynchronize(); + } + + // each thread keeps has some number of "expanded indexes" assigned to it + // for each of these, we keep the associated expert and offset within expert in registers + int32_t expertIndexes[MaxExpandedIdxPerThread]; + int32_t expertOffsets[MaxExpandedIdxPerThread]; + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a + // time, and branch between a fast path without bound checks and a slow path with bound checks. + int constexpr IterStride = 4; + static_assert(MaxExpandedIdxPerThread % IterStride == 0); + + // Define a lambda to avoid code duplication in both branches. + // Use shared device function for expert index extraction. + auto loopBody = [&](int ii, int expandedIdx) + { + int32_t expertIdx = getExpertIdxFromInputWithWeights(params, expandedIdx, params.mPtrTopKWeights); + expertIndexes[ii] = expertIdx; + // check whether this expert is local to our GPU at all and ignore if not + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent + && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0; + }; + +#pragma unroll + for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) + { + // Whether it's safe to do multiple iterations without bound checks. + bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; + if (takeFastPath) + { +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) + { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + loopBody(ii, expandedIdx); + } + } + else + { + bool doBreak = false; +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) + { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) + { + doBreak = true; + break; + } + loopBody(ii, expandedIdx); + } + if (doBreak) + { + break; + } + } + } + + // Make histogram (token counts per expert) available to all threads in the block. + __syncthreads(); + + // + // Each thread now represents one expert + // + + // Add the local bin count to the common bin count and get a per-CTA offset. + int32_t const localExpertCount = smemExpertCount[threadIdx.x]; + + int32_t blockExpertOffset = 0; + if (threadIdx.x < params.mNumExperts) + { + blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); + } + + // Sync to wait for completion of the histogram reduction. + grid.sync(); + + // Get total count for this expert. + int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; + + // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency. + + // Compute the runtime config for projections + // Whether or not an expert is local is taken into account when smemExpertCount is computed + // so we do not need to take it into account here. + + int32_t numCta; + if (params.mIsPow2) + { + numCta = divUpLog2(count, params.mPaddingLog2); + } + else + { + numCta = divUpTileN(count, params.mTileTokensDim); + } + + int32_t ctaOffset; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + + for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) + { + const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if (params.mIsPow2) + { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } + else + { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); + } + + int32_t offset; + if (params.mIsPow2) + { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } + else + { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } + int32_t permutedIdxSize; + if (params.mIsPow2) + { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } + else + { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } + + // write out padded count + if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) + { + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + + // write expert offsets to shared + smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; + + // make expert offsets available to all threads + __syncthreads(); + + // each thread has the same "expanded indexes" assigned to it as above + // at this point, we know the final offsets of experts and the offsets within + // experts, which allows writing the final index values +#pragma unroll + for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) + { + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) + { + break; + } + auto expertIdx = expertIndexes[ii]; + // check whether this expert is local to our GPU at all + auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent + && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + auto tokenIdx = expandedIdx / params.mTopK; + auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) + { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + } + if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) + { + params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; + } + if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) + { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } + + // Trigger the secondary kernel AFTER all global memory writes (including permutation indices). + // The downstream kernels depend on all routing outputs being visible. + if (params.mUsePdl) + { + cudaTriggerProgrammaticLaunchCompletion(); + } +} +#else +template +__global__ void routingIndicesCoopKernel(KernelParams params) +{ + assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Shared device functions for coop kernel (used by both routingCustom and routingDeepSeek) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Device function to extract expert index from either mPtrTopKIds or mPtrTopKPacked. +// This is the only difference between routingCustom and routingDeepSeek coop kernels. +// For routingCustom: also writes to mPtrTopKWeights if provided. +// For routingDeepSeek: simpler version that doesn't write weights. +template +__forceinline__ __device__ int32_t getExpertIdxFromInput(KernelParams const& params, int32_t expandedIdx) +{ + if (params.mPtrTopKIds != nullptr) + { + return params.mPtrTopKIds[expandedIdx]; + } + else + { + return params.mPtrTopKPacked[expandedIdx].idx; + } +} + +// Overload for routingCustom that also writes topK weights if needed. +template +__forceinline__ __device__ int32_t getExpertIdxFromInputWithWeights( + KernelParams const& params, int32_t expandedIdx, typename KernelParams::OutputT* topKWeights) +{ + if (params.mPtrTopKIds != nullptr) + { + return params.mPtrTopKIds[expandedIdx]; + } + else + { + PackedScoreIdx scoreIdx = params.mPtrTopKPacked[expandedIdx]; + if (topKWeights != nullptr) + { + topKWeights[expandedIdx] = static_cast(scoreIdx.score); + } + return scoreIdx.idx; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace routing } // namespace moe::dev diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.h similarity index 67% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.h index 3daa1848e5d..8ca6c657e40 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,6 +44,12 @@ struct DataBase { bool mUsePdl{false}; + // Controls programmaticStreamSerializationAllowed launch attribute independently of mUsePdl. + // When false, the launch attribute is 0 even if mUsePdl is true, preventing the NEXT kernel + // from starting early. The kernel still does cudaGridDependencySynchronize/Trigger internally. + // Set to false on the LAST routing kernel so downstream GEMM waits via stream serialization. + bool mPdlAllowOverlap{true}; + // optional: only used as an intermediate buffer when the number of tokens is large. // dim: max([2*NumThreads] = [512], mNumExperts*2) int32_t* mPtrExpertCounts{nullptr}; @@ -100,8 +106,10 @@ struct DataBase int32_t mNumTokens; int32_t mNumExperts; int32_t mTopK; - int32_t mPaddingLog2; + // Cluster-wide tile size in token dimension. int32_t mTileTokensDim; + // log2() of the padding size in cluster-wide tile. + int32_t mPaddingLog2; /// For expert parallelization int32_t mLocalExpertsStartIdx; @@ -109,15 +117,16 @@ struct DataBase int32_t mNumLocalExperts; }; -template +template struct KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr int MaxNumExperts = MaxNumExperts_; static constexpr int MaxNumTopExperts = MaxNumTopExperts_; - static constexpr bool UsePdl = UsePdl_; - static constexpr bool isPow2 = isPow2_; + + bool mUsePdl = false; + bool mIsPow2 = false; // Public pointer members int32_t* mPtrExpertCounts = nullptr; @@ -146,6 +155,8 @@ struct KernelParamsBase template void setBaseParams(DataType const& data) { + mUsePdl = data.mUsePdl; + mIsPow2 = data.mPaddingLog2 > 0; mPtrExpertCounts = data.mPtrExpertCounts; mPtrPermutedIdxSize = data.mPtrPermutedIdxSize; mPtrExpandedIdxToPermutedIdx = data.mPtrExpandedIdxToPermutedIdx; @@ -175,12 +186,14 @@ namespace routingDeepSeek //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data : public DataBase { - tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeOutput{tg::Dtype::Bfloat16}; // // Grouped Gemm Launch Config Buffers // void const* mPtrRoutingBias; + // Dtype of the routing bias buffer (Bfloat16 or Fp32). + tg::Dtype mDtypeBias{tg::Dtype::Bfloat16}; int32_t mHiddenDim; // not used int32_t mNumExpertGroups; @@ -190,9 +203,8 @@ struct Data : public DataBase bool mUseRoutingSoftmax; }; -template -struct KernelParams : public KernelParamsBase +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; @@ -205,7 +217,9 @@ struct KernelParams : public KernelParamsBase*) data.mPtrTopKPacked; - // params.mPtrTopKWeightsFull = static_cast(data.mPtrTopKWeightsFull); - params.mPtrRoutingBias = static_cast(data.mPtrRoutingBias); + params.mPtrRoutingBias = data.mPtrRoutingBias; + params.mDtypeBias = data.mDtypeBias; params.mNumExpertGroups = data.mNumExpertGroups; params.mNumExpertsPerGroup = data.mNumExperts / data.mNumExpertGroups; @@ -247,11 +261,11 @@ namespace routingLlama4 struct Data : public DataBase { - tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeOutput{tg::Dtype::Bfloat16}; }; -template -struct KernelParams : public KernelParamsBase +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; @@ -277,40 +291,69 @@ void run(Data const& data, void* stream); //////////////////////////////////////////////////////////////////////////////////////////////////// -namespace routingRenormalize +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Routing preprocess/postprocess policy type enums. +// These are used to select the compile-time policy at dispatch time. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +enum class RoutingPreprocessType +{ + None, // No preprocessing before topK + Softmax, // Apply softmax on all expert scores before topK + Sigmoid, // Apply sigmoid(score) for topK selection (Cohere-style, no bias) + SigmoidBias, // Apply sigmoid(score) + bias for topK selection (DeepSeek-style) +}; + +enum class RoutingPostprocessType +{ + None, // No postprocessing after topK + Softmax, // Apply softmax on top-K scores + SumNormalize, // Normalize top-K scores by their sum + ScaledSumNormalize, // Recover sigmoid scores, normalize by sum and scale (DeepSeek-style) +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace routingCustom { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data : public DataBase { - tg::Dtype mDtypeExpW{tg::Dtype::Fp32}; - tg::Dtype mDtypeElt{tg::Dtype::Bfloat16}; + tg::Dtype mDtypeOutput{tg::Dtype::Fp32}; // OutputT: expert weights dtype (typically Bfloat16) + tg::Dtype mDtypeInput{tg::Dtype::Bfloat16}; // InputT: routing logits dtype (Bfloat16 or Fp32) - bool mDoSoftmaxBeforeTopK{false}; + RoutingPreprocessType mPreprocessType{RoutingPreprocessType::None}; + RoutingPostprocessType mPostprocessType{RoutingPostprocessType::Softmax}; bool mNormTopkProb{true}; // Default value is true for Qwen3 model - // If true, applies softmax normalization after selecting top-K experts. - // Use this for models that require post-selection normalization (e.g., specific Qwen variants). - // Mutually exclusive with mDoSoftmaxBeforeTopK when both normalization paths are active. - // NOTE: Don't need to use this variable for now. - bool mApplySoftmaxAfterTopK{true}; + + // Optional: per-expert routing bias (used by SigmoidBias preprocess). + void const* mPtrRoutingBias{nullptr}; + // Dtype of the routing bias buffer (Bfloat16 or Fp32). Used to read mPtrRoutingBias correctly. + tg::Dtype mDtypeBias{tg::Dtype::Bfloat16}; + // Optional: scaling factor applied to final scores (used by ScaledSumNormalize postprocess). + float mRouteScale{1.0f}; + // Optional: epsilon added to the sum before division to prevent division by zero. + // MiniMax2 uses 1e-20f; DeepSeek uses 0.0f (no epsilon). + float mSumEpsilon{0.0f}; }; -template -struct KernelParams : public KernelParamsBase +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; + using ExpertSelectPolicy = ExpertSelectPolicy_; - static constexpr bool DoSoftmaxBeforeTopK = DoSoftmaxBeforeTopK_; + // Expert select policy params — empty structs have zero register cost. + using ExpertSelectParams = typename ExpertSelectPolicy::template Params; PackedScoreIdx* mPtrTopKPacked = nullptr; int32_t mTopK = 0; - bool mNormTopkProb = true; - bool mApplySoftmaxAfterTopK = false; + ExpertSelectParams mExpertSelectParams; static KernelParams setKernelParams(Data const& data) { @@ -318,16 +361,33 @@ struct KernelParams : public KernelParamsBase*) data.mPtrTopKPacked; - params.mNormTopkProb = data.mNormTopkProb; - params.mApplySoftmaxAfterTopK = data.mApplySoftmaxAfterTopK; params.mTopK = data.mTopK; + + // Policy populates only the fields it needs from Data. + params.mExpertSelectParams.set(data); return params; } }; void run(Data const& data, void* stream); -} // namespace routingRenormalize +} // namespace routingCustom + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Shared utility for post-topK pipeline when mPtrTopKIds != nullptr. +// All routing methods (Custom, DeepSeek, Llama4) use the same workflow in this case: +// 1. Reset expert counts +// 2. Run histogram kernel +// 3. Run offsets kernel +// Since the kernels are shared and we don't need routing-method-specific logic, +// we can use routingCustom's launch mechanism. +// +// This function works with any Data type that inherits from DataBase. +// Implementation is in RoutingFromTopKIds.cu +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void runPostTopKPipeline(DataType const& data, uint32_t numThreadsHist, void* stream); //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernelTopK.cuh similarity index 100% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernelTopK.cuh rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingKernelTopK.cuh diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingLlama4.cu similarity index 92% rename from cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu rename to cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingLlama4.cu index 3362eb80c1b..28435e548d2 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingLlama4.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routing/RoutingLlama4.cu @@ -106,7 +106,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // then wait on primary grid - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -165,9 +165,9 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam static_cast(params.mPtrTopKPacked[threadIdx.x].idx)}; if (params.mPtrTopKWeights != nullptr) { - // we also compute the final score here and write it out if required - auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})}; - params.mPtrTopKWeights[threadIdx.x] = finalScore; + // mPtrTopKPacked already contains sigmoid scores (produced by the scores-path + // kernels), so we just pass them through — no need to apply sigmoid again. + params.mPtrTopKWeights[threadIdx.x] = scoreIdx.score; } } } @@ -208,7 +208,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam { auto count = getBits(expertCount, ii); int32_t num; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { num = divUpLog2(count, params.mPaddingLog2); } @@ -231,7 +231,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam { auto count = getBits(expertCount, ii); int32_t finalNumCta; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { finalNumCta = divUpLog2(count, params.mPaddingLog2); } @@ -240,14 +240,12 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam finalNumCta = divUpTileN(count, params.mTileTokensDim); } auto expertIdx = threadIdx.x * ExpertsPerThread + ii; - // during the scan for expert offsets, we can already write out - // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit` for (int cta = 0; cta < finalNumCta; ++cta) { params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx; int32_t mnLimit1; int32_t mnLimit2; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { mnLimit1 = mulLog2(ctaOffsetExp + cta + 1, params.mPaddingLog2); mnLimit2 = mulLog2(ctaOffsetExp, params.mPaddingLog2) + count; @@ -266,7 +264,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam if (cute::elect_one_sync()) { int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); } @@ -281,7 +279,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // we can trigger the next kernel at this point - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } @@ -294,7 +292,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam // of registers auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; int32_t finalExpertOffset[ExpertsPerThread]; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { finalExpertOffset[0] = mulLog2(ctaOffset, params.mPaddingLog2); } @@ -306,7 +304,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam for (int ii = 1; ii < ExpertsPerThread; ++ii) { int32_t tmp; - if constexpr (KernelParams::isPow2) + if (params.mIsPow2) { tmp = divUpMulLog2(getBits(expertCount, ii - 1), params.mPaddingLog2); } @@ -387,7 +385,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu auto warp = cg::tiled_partition(block); // then wait on primary grid - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); } @@ -480,7 +478,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHis #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) // Wait on primary grid and trigger secondary kernel. - if constexpr (KernelParams::UsePdl) + if (params.mUsePdl) { cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); @@ -546,10 +544,22 @@ void run(Data const& data, void* stream) { TLLM_CHECK_WITH_INFO(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) + // When topK is already computed (mPtrTopKIds or mPtrTopKPacked without scores), + // delegate to the shared post-topK pipeline. This avoids Llama4-specific issues: + // - The Llama4 cluster kernel loads one token per warp but useSingleCluster uses + // the thread-based capacity, causing unprocessed tokens for medium token counts. + // - The Llama4 device kernel applies sigmoid to packed scores that may already + // contain sigmoid values (produced by the scores-path kernels). + if (data.mPtrTopKIds != nullptr || (data.mPtrTopKPacked != nullptr && data.mPtrScores == nullptr)) { - TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Llama4 routing."); + if (data.mPtrTopKIds != nullptr) + { + TLLM_CHECK_WITH_INFO(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for Llama4 routing."); + } + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + runPostTopKPipeline(data, numThreadsHist, stream); + return; } TLLM_CHECK_WITH_INFO(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, @@ -563,15 +573,16 @@ void run(Data const& data, void* stream) TLLM_CHECK_WITH_INFO( data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + // After this point, mPtrTopKIds is guaranteed to be nullptr. + // Input is either mPtrScores (raw logits) or mPtrTopKPacked (topK already computed, needs sigmoid). bool const useSingleWarp = (data.mPtrScores == nullptr && data.mNumTokens <= WarpKernelMaxNumTokens) || data.mNumTokens < WarpKernelMaxNumTokens; - bool const useSingleCluster = data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) - ? MaxNumTokensSingleClusterScores - : MaxNumTokensSingleCluster); + bool const useSingleCluster = data.mNumTokens + <= ((data.mPtrScores != nullptr) ? MaxNumTokensSingleClusterScores : MaxNumTokensSingleCluster); if (!useSingleCluster) { - TLLM_CHECK_WITH_INFO((data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), - "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); + TLLM_CHECK_WITH_INFO( + data.mPtrTopKPacked != nullptr, "When #tokens is large, `mPtrTopKPacked` is a required input."); TLLM_CHECK_WITH_INFO( data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); } @@ -606,7 +617,7 @@ void run(Data const& data, void* stream) int const numBlocksOffsets = std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) + if (data.mPtrScores != nullptr) { LAUNCH_ROUTING_LLAMA4(data, /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, numThreadsHist, diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/RoutingDeepSeekCommon.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/RoutingDeepSeekCommon.cuh deleted file mode 100644 index b9673be5efe..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/RoutingDeepSeekCommon.cuh +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../RoutingKernel.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// -static constexpr int NumNemotronExperts = 512; -static constexpr int NumKimiK2Experts = 384; -static constexpr int NumDeepseekExperts = 256; -static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); -static constexpr int NumTopGroupScores = 2; -static constexpr int MaxNumTopGroups = 4; -static constexpr int MaxNumGroups = 8; - -static constexpr int NumTop8Experts = 8; -static constexpr int NumTop22Experts = 22; -static constexpr int MaxSupportedTopExperts = 32; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -int constexpr getMaxNumExperts(int32_t numExperts) -{ - if (numExperts <= topk::MaxNumExpertsUnit) - { - return topk::MaxNumExpertsUnit; - } - else if (numExperts <= NumDeepseekExperts) - { - return NumDeepseekExperts; - } - else if (numExperts <= NumKimiK2Experts) - { - return NumKimiK2Experts; - } - else if (numExperts <= NumNemotronExperts) - { - return NumNemotronExperts; - } - else - { - TLLM_LOG_ERROR("Unsupported numExperts"); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Helper macro: dispatch on topK tier for a given numExperts tier. -#define LAUNCH_DEEPSEEK_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, forceFloatInput, numExperts) \ - if (data.mTopK <= NumTop8Experts) \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, numExperts, NumTop8Experts); \ - } \ - else if (data.mTopK <= NumTop22Experts) \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, numExperts, NumTop22Experts); \ - } \ - else \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, forceFloatInput, numExperts, MaxSupportedTopExperts); \ - } - -#define LAUNCH_ROUTING_DEEPSEEK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, forceFloatInput) \ - if (data.mNumExperts <= topk::MaxNumExpertsUnit) \ - { \ - LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - forceFloatInput, topk::MaxNumExpertsUnit); \ - } \ - else if (data.mNumExperts <= NumDeepseekExperts) \ - { \ - LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - forceFloatInput, NumDeepseekExperts); \ - } \ - else if (data.mNumExperts <= NumKimiK2Experts) \ - { \ - LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - forceFloatInput, NumKimiK2Experts); \ - } \ - else if (data.mNumExperts <= NumNemotronExperts) \ - { \ - LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - forceFloatInput, NumNemotronExperts); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported numExperts"); \ - } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchClusterKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchClusterKernel.cu deleted file mode 100644 index 14fc591f4e5..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchClusterKernel.cu +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesClusterKernel(KernelParams params) -{ - using OutputT = typename KernelParams::OutputT; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const clusterBlockRank = blockIdx.x; - - //@todo: try to move it into routingPermutation - // then wait on primary grid - if constexpr (KernelParams::UsePdl) - { - cudaGridDependencySynchronize(); - } - routingPermutation(params, nullptr, warpIdx, clusterBlockRank); -} -#else -__global__ void routingIndicesClusterKernel(KernelParams params) -{ - assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchClusterKernel(Data& data, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchCoopKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchCoopKernel.cu deleted file mode 100644 index a96db74865d..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchCoopKernel.cu +++ /dev/null @@ -1,276 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesCoopKernel(KernelParams params) -{ - // number of experts is bounded by number of threads - int constexpr NumThreads = KernelParams::MaxNumExperts; - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; - // needed for the exclusive sum of token offsets - using Scan = cub::BlockScan; - __shared__ typename Scan::TempStorage tempStorage; - // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. - static constexpr int MaxExpandedIdxPerThread = 64; - - // Initialize grid. - cg::grid_group grid = cg::this_grid(); - // Note: the following is more efficient than grid.block_index() because we don't use y and z. - int32_t const gridBlockIdx = blockIdx.x; - int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; - int32_t const numBlocks = gridDim.x; - int32_t const numThreadsPerGrid = numBlocks * NumThreads; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - - auto expandedIdxSize = params.mNumTokens * params.mTopK; - - // pre-fill the counts with 0 - smemExpertCount[threadIdx.x] = 0; - __syncthreads(); - - // then wait on primary grid - if constexpr (KernelParams::UsePdl) - { - cudaGridDependencySynchronize(); - } - - // each thread keeps has some number of "expanded indexes" assigned to it - // for each of these, we keep the associated expert and offset within expert in registers - int32_t expertIndexes[MaxExpandedIdxPerThread]; - int32_t expertOffsets[MaxExpandedIdxPerThread]; - auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a - // time, and branch between a fast path without bound checks and a slow path with bound checks. - int constexpr IterStride = 4; - static_assert(MaxExpandedIdxPerThread % IterStride == 0); - - // Define a lambda to avoid code duplication in both branches. - auto loopBody = [&](int ii, int expandedIdx) - { - int32_t expertIdx - = params.mPtrTopKIds != nullptr ? params.mPtrTopKIds[expandedIdx] : params.mPtrTopKPacked[expandedIdx].idx; - expertIndexes[ii] = expertIdx; - // check whether this expert is local to our GPU at all and ignore if not - auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent - && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; - expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIdx, 1) : 0; - }; - -#pragma unroll - for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) - { - // Whether it's safe to do multiple iterations without bound checks. - bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; - if (takeFastPath) - { -#pragma unroll - for (int32_t jj = 0; jj < IterStride; jj++) - { - int const ii = ii0 + jj; - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - loopBody(ii, expandedIdx); - } - } - else - { - bool doBreak = false; -#pragma unroll - for (int32_t jj = 0; jj < IterStride; jj++) - { - int const ii = ii0 + jj; - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - if (expandedIdx >= expandedIdxSize) - { - doBreak = true; - break; - } - loopBody(ii, expandedIdx); - } - if (doBreak) - { - break; - } - } - } - - // Make histogram (token counts per expert) available to all threads in the block. - __syncthreads(); - - // - // Each thread now represents one expert - // - - // Add the local bin count to the common bin count and get a per-CTA offset. - int32_t const localExpertCount = smemExpertCount[threadIdx.x]; - - int32_t blockExpertOffset = 0; - if (threadIdx.x < params.mNumExperts) - { - blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); - } - - // Sync to wait for completion of the histogram reduction. - grid.sync(); - - // Get total count for this expert. - int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; - - // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency. - - // Compute the runtime config for projections - // Whether or not an expert is local is taken into account when smemExpertCount is computed - // so we do not need to take it into account here. - - int32_t numCta; - if constexpr (KernelParams::isPow2) - { - numCta = divUpLog2(count, params.mPaddingLog2); - } - else - { - numCta = divUpTileN(count, params.mTileTokensDim); - } - - int32_t ctaOffset; - int32_t numNonExitingCtas; - Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - - for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) - { - const int32_t localExpertIdx = (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) - { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; - } - else - { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; - } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); - } - - // get the padded offset associated with this expert - int32_t offset; - if constexpr (KernelParams::isPow2) - { - offset = mulLog2(ctaOffset, params.mPaddingLog2); - } - else - { - offset = mulTileN(ctaOffset, params.mTileTokensDim); - } - int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) - { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); - } - else - { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); - } - - // write out padded count - if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) - { - params.mPtrPermutedIdxSize[0] = permutedIdxSize; - params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; - } - - // write expert offsets to shared - smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; - - // make expert offsets available to all threads - __syncthreads(); - - // trigger the secondary kernel when using PDL - // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, - // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens - // TODO: this is not sufficient to ensure visibility in the next kernel! - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - } - -// each thread has the same "expanded indexes" assigned to it as above -// at this point, we know the final offsets of experts and the offsets within -// experts, which allows writing the final index values -#pragma unroll - for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) - { - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - if (expandedIdx >= expandedIdxSize) - { - break; - } - auto expertIdx = expertIndexes[ii]; - // check whether this expert is local to our GPU at all - auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent - && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; - auto tokenIdx = expandedIdx / params.mTopK; - auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; - if (params.mPtrExpandedIdxToPermutedIdx != nullptr) - { - params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; - } - if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) - { - params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; - } - if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) - { - params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; - } - } -} -#else -__global__ void routingIndicesCoopKernel(KernelParams params) -{ - assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchCoopKernel(Data& data, int numBlocksCoop, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchHistogramKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchHistogramKernel.cu deleted file mode 100644 index 1263e289e13..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchHistogramKernel.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchInitExpertCounts.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchInitExpertCounts.cu deleted file mode 100644 index 5f265878a38..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchInitExpertCounts.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchInitExpertCounts(Data& data, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts, (2 * data.mNumExperts - 1) / numThreadsHist + 1, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/false); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchMainKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchMainKernel.cu deleted file mode 100644 index 1edc469cf70..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchMainKernel.cu +++ /dev/null @@ -1,289 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__global__ void routingMainKernel(KernelParams params) -{ - // declare types - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - - // declare shared memory structure - // number of experts is bounded by number of threads - __shared__ float __attribute((aligned(128))) smemScoreSigmoid[KernelParams::MaxNumExperts]; - __shared__ float __attribute((aligned(128))) smemScoreBias[KernelParams::MaxNumExperts]; - // number of expert groups is bounded by number of warps - __shared__ float __attribute((aligned(128))) smemGroupScores[MaxNumGroups]; - - // needed for warp reduce - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - // for the final reduction of weight norm, only some lanes need to participate - int32_t laneIdx = threadIdx.x % WarpSize; - int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - // warps outside the range of expert groups do not participate - if constexpr (KernelParams::UseGroups) - { - if (warpIdx >= params.mNumExpertGroups) - { - return; - } - } - - // note that for invalid scores, we simply use a negative value: - // they work well even with the compacted format used in topK, and - // sigmoid / bias activated scores cannot be negative - static constexpr float invalidScoreFloat = float{-INFINITY}; - const OutputT invalidScore = OutputT{invalidScoreFloat}; - - // load bias already; each warp represents one expert group - auto threadExpert = threadIdx.x; - bool expertSelected = threadExpert < params.mNumExperts; - if constexpr (KernelParams::UseGroups) - { - threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx; - expertSelected = laneIdx < params.mNumExpertsPerGroup; - } - auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert; - auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore; - - // initialize the mPtrExpertCounts - if (params.mPtrExpertCounts) - { - int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; - int32_t globalThreadStride = gridDim.x * blockDim.x; - int32_t expertCountsNum = 2 * params.mNumExperts; - initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); - } - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - // trigger the secondary kernel when using PDL, then wait on primary - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - cudaGridDependencySynchronize(); - } -#endif - - if (params.mPtrScores != nullptr) - { - // get our assigned thread score; each warp represents one expert group - float score = expertSelected ? static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; - // get the sigmoid score - // note that for invalid values, we simply use a negative value: - // sigmoig scores are always strictly positive - auto scoreSigmoid = sigmoid_accurate(score); - // write the sigmoid score to shared for later use - if (expertSelected) - { - smemScoreSigmoid[threadExpert] = scoreSigmoid; - } - // get the score with bias - // note that with invalid values, because sigmoid is < 1 and bias is -1, - // we must get a negative value, which is smaller than any valid value - auto scoreBias = float{scoreSigmoid + float{biasVal}}; - - if (expertSelected) - { - smemScoreBias[threadExpert] = scoreBias; - } - - // registers for top group score reduction - float topExpGroupScores[NumTopGroupScores]; - [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; - float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups - int32_t topGroupIdx[MaxNumTopGroups]; - float expertScoreGroup[MaxNumTopGroups]; - int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK - int32_t topExperts[KernelParams::MaxNumTopExperts]; - - if constexpr (KernelParams::UseGroups) - { - topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, - /* minValue */ invalidScoreFloat); - // get the final group score and write it to shared - if (cute::elect_one_sync()) - { - auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; - smemGroupScores[warpIdx] = groupScore; - } - } - - // make group scores available to all warps - __syncthreads(); - - auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - if constexpr (KernelParams::UseGroups) - { // a single warp performs the selection of top groups, and goes on to select the final experts - if (warpIdx == 0) - { - float groupScore = laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat; - topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, - /* minValue */ invalidScoreFloat); - // final expert selection: get relevant indexes and scores from shared -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { // bound of params.mNumLimitedGroups - auto groupIdx = topGroupIdx[ii]; - expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; - // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. - // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, - // thus groupIdx <= params.mNumExpertGroups - 1 => - // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - params.mNumExpertsPerGroup - // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, - // so the access is safe here - expertScoreGroup[ii] - = (ii < params.mNumLimitedGroups) && (groupIdx < params.mNumExpertGroups) && expertSelected - ? smemScoreBias[expertIdxGroup[ii]] - : invalidScoreFloat; - } - - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } - else if constexpr (KernelParams::MaxNumExperts > topk::MaxNumExpertsUnit) - { - // without groups, each thread just takes `MaxNumTopGroups` experts - int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; - __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; - __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; - if (warpIdx < NumExpertWarps) - { - int offset = warpIdx * WarpSize * MaxNumTopGroups; -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { - auto expertIdx = ii * WarpSize + laneIdx; - expertIdxGroup[ii] = offset + expertIdx; - expertScoreGroup[ii] = offset + expertIdx < params.mNumExperts ? smemScoreBias[offset + expertIdx] - : invalidScoreFloat; - } - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - - if (laneIdx < params.mTopK) - { - smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; - } - else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) - { - smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = invalidScoreFloat; - smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] - = MaxSupportedExpertCount - 1; - } - } - __syncthreads(); - if (warpIdx == 0) - { - int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; - float intermediateScore[NumInterTopKPerThread]; - int32_t intermediateExpert[NumInterTopKPerThread]; - for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) - { - int ii = i / WarpSize; - if (i < NumInterTopK) - { - intermediateScore[ii] = smemInterTopScores[i]; - intermediateExpert[ii] = smemInterTopExperts[i]; - } - else - { - intermediateScore[ii] = invalidScoreFloat; - intermediateExpert[ii] = KernelParams::MaxNumExperts - 1; - } - } - topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } - else - { - if (warpIdx == 0) - { - // without groups, each thread just takes `MaxNumTopGroups` experts -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) - { - auto expertIdx = ii * WarpSize + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] - = expertIdx < params.mNumExperts ? smemScoreBias[expertIdx] : invalidScoreFloat; - } - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } - - if (warpIdx == 0) - { - // determine our lane's expert index and write to output - int32_t expertIdx = 0; -#pragma unroll - for (int ii = 0; ii < params.mTopK; ++ii) - { // bound of params.mTopK - expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx; - } - // determine whether our expert is local to this GPU - auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent - && (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; - - float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; - auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); - auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; - - // write expert idx out already - auto idxTopK = blockIdx.x * params.mTopK + laneIdx; - if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) - { - PackedScoreIdx packedScore{static_cast(finalScore), static_cast(expertIdx)}; - params.mPtrTopKPacked[idxTopK] = packedScore; - } - - if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) - { - params.mPtrTopKWeights[idxTopK] = finalScore; - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchOffsetsKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchOffsetsKernel.cu deleted file mode 100644 index 0836c21aa5c..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingDeepSeek/launchOffsetsKernel.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingDeepSeekCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingDeepSeek -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/RoutingRenormalizeCommon.cuh b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/RoutingRenormalizeCommon.cuh deleted file mode 100644 index 31bfb399547..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/RoutingRenormalizeCommon.cuh +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../RoutingKernel.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static constexpr int NumExperts128Experts = 128; -static constexpr int NumExperts512Experts = 512; -static constexpr int MaxSupportedExperts = 2048; - -static constexpr int NumTop8Experts = 8; -static constexpr int NumTop16Experts = 16; -static constexpr int MaxSupportedTopExperts = 32; - -static constexpr int NumThreads = 1024; -static constexpr int NumWarps = NumThreads / WarpSize; - -static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; -static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; - -static constexpr int BlockKernelMaxNumTokens = 4; -static constexpr int DynBlockKernelMaxNumTokens = 16; -static constexpr int DynBlockKernelMaxNumExperts = 512; - -template -__forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile const& warp, - DataType (&score)[VecSize], int32_t (&idx)[VecSize], DataType (&warpTopKScore)[K], int32_t (&warpTopKExpertIdx)[K], - int32_t const laneIdx, int32_t const numExperts, int32_t topK, InputType const* ptrScores, bool const normTopkProb, - bool const applySoftmaxAfterTopK = true) -{ - DataType minScore = DataType{-INFINITY}; - - for (int i = 0; i < VecSize; i++) - { - auto expertIdx = i * WarpSize + laneIdx; - auto newScore = expertIdx < numExperts ? static_cast(ptrScores[expertIdx]) : minScore; - score[i] = newScore; - idx[i] = expertIdx; - } - if constexpr (DoSoftmaxBeforeTopK) - { - calcSoftmax(warp, score); - } - - // Get the top-k scores and their corresponding expert indices - topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); - - // Normalize the scores - if constexpr (DoSoftmaxBeforeTopK) - { - float sum = float{1.f}; - if (normTopkProb) - { - sum = static_cast(laneIdx < topK ? warpTopKScore[laneIdx] : 0); - sum = cg::reduce(warp, sum, cg::plus()); - } - if (laneIdx < topK) - { - warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; - } - } - else - { - if (applySoftmaxAfterTopK) - { - auto softmaxScore = calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK); - if (laneIdx < topK) - { - warpTopKScore[laneIdx] = softmaxScore; - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -int32_t constexpr getMaxNumExperts(int32_t numExperts) -{ - if (numExperts <= NumExperts128Experts) - { - return NumExperts128Experts; - } - else if (numExperts <= NumExperts512Experts) - { - return NumExperts512Experts; - } - else if (numExperts <= MaxSupportedExperts) - { - return MaxSupportedExperts; - } - else - { - TLLM_LOG_ERROR("Unsupported numExperts"); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// Helper macro: dispatch on topK tier for a given numExperts tier. -#define LAUNCH_ROUTING_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, numExperts) \ - if (data.mTopK <= NumTop8Experts) \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - numExperts, NumTop8Experts); \ - } \ - else if (data.mTopK <= NumTop16Experts) \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - numExperts, NumTop16Experts); \ - } \ - else \ - { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ - numExperts, MaxSupportedTopExperts); \ - } - -#define LAUNCH_ROUTING_RENORMALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1) \ - if (data.mNumExperts <= NumExperts128Experts) \ - { \ - LAUNCH_ROUTING_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, NumExperts128Experts); \ - } \ - else if (data.mNumExperts <= NumExperts512Experts) \ - { \ - LAUNCH_ROUTING_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, NumExperts512Experts); \ - } \ - else if (data.mNumExperts <= MaxSupportedExperts) \ - { \ - LAUNCH_ROUTING_WITH_TOPK( \ - data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, MaxSupportedExperts); \ - } \ - else \ - { \ - TLLM_LOG_ERROR("Unsupported numExperts"); \ - } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchClusterKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchClusterKernel.cu deleted file mode 100644 index b8d7f8b9118..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchClusterKernel.cu +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) - routingIndicesClusterKernel(KernelParams params) -{ - // number of tokens/expanded idx is bounded by total number of warps - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - - using BaseType = std::conditional_t; - using TypePacked = PackedScoreIdx; - - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - - __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * KernelParams::MaxNumTopExperts]; - - uint32_t const clusterBlockRank = blockIdx.x; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const laneIdx = cutlass::arch::LaneId(); - - auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; - auto scoreOffset = warpTokenIdx * params.mNumExperts; - bool validToken = warpTokenIdx < params.mNumTokens; - - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - // then wait on primary grid - if constexpr (KernelParams::UsePdl) - { - cudaGridDependencySynchronize(); - } - - if (params.mPtrScores != nullptr) - { - // in this case, each warp represents a token - BaseType score[VecSize]; - int32_t idx[VecSize]; - - BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; - int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; - - BaseType minScore = BaseType{-INFINITY}; - if (validToken) - { - routingTopKExperts(warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb); - - if (laneIdx < params.mTopK) - { - smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] - = TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; - } - } // end if (validToken) - } - - // make packed scores available to all threads in cluster - __cluster_barrier_arrive(); - __cluster_barrier_wait(); - - if (params.mPtrScores != nullptr) - { - routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); - } - else - { - routingPermutation(params, smemPackedScoreIdx, warpIdx, clusterBlockRank); - } -} -#else -__global__ void __launch_bounds__(NumThreads) routingIndicesClusterKernel(KernelParams /* params */) -{ - assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); -} -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchClusterKernel(Data const& data, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramKernel.cu deleted file mode 100644 index 7d6f6177a56..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramKernel.cu +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramScoresKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramScoresKernel.cu deleted file mode 100644 index 03bec526f82..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchHistogramScoresKernel.cu +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// this kernel is needed in case we have scores as input for the histogram kernel -template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024) - routingIndicesHistogramScoresKernel(KernelParams params) -{ - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; - // Cap actual thread count at 1024 when MaxNumExperts > 1024. - static constexpr int NumThreadsBlock = KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024; - - // VecSize stays based on MaxNumExperts — each warp still processes all experts for one token. - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - - int32_t const laneIdx = cutlass::arch::LaneId(); - int32_t const warpIdx = threadIdx.x / WarpSize; - // Use NumThreadsBlock (actual thread count) for grid-stride warp/thread addressing - int32_t const globalWarpIdx = blockIdx.x * NumThreadsBlock / WarpSize + warpIdx; - int32_t const globalWarpStride = gridDim.x * NumThreadsBlock / WarpSize; - BaseType minScore = BaseType{-INFINITY}; - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid. - if constexpr (KernelParams::UsePdl) - { - cudaGridDependencySynchronize(); - } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - - // initialize the mPtrExpertCounts — use NumThreadsBlock for grid-stride - int32_t expertCountsNum = 2 * params.mNumExperts; - int32_t globalThreadIdx = blockIdx.x * NumThreadsBlock + threadIdx.x; - int32_t globalThreadStride = gridDim.x * NumThreadsBlock; - initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); - - // in this case, each warp represents a token, and we use a grid-stride loop - // over all warps/tokens - BaseType allScores[VecSize]; - int32_t allExpertIdx[VecSize]; - BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; - int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; - for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) - { - auto scoreOffset = tokenIdx * params.mNumExperts; - - routingTopKExperts(warp, allScores, allExpertIdx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb); - - if (laneIdx < params.mTopK) - { - PackedScoreIdx packedScore{ - static_cast(warpTopKScore[laneIdx]), static_cast(warpTopKExpertIdx[laneIdx])}; - params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; - } - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Trigger secondary kernel AFTER writing all packed scores, so the next kernel - // (routingIndicesHistogramKernel) sees the completed mPtrTopKPacked writes. - if constexpr (KernelParams::UsePdl) - { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, uint32_t numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchInitExpertCounts.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchInitExpertCounts.cu deleted file mode 100644 index 807fc89e897..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchInitExpertCounts.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingInitExpertCounts, (2 * data.mNumExperts - 1) / numThreadsHist + 1, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchOffsetsKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchOffsetsKernel.cu deleted file mode 100644 index fe398c80cdb..00000000000 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/routingRenormalize/launchOffsetsKernel.cu +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "RoutingRenormalizeCommon.cuh" - -namespace moe::dev::routing -{ -namespace routingRenormalize -{ - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, void* stream) -{ - LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu index 467bca9318a..150b729aa54 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu @@ -15,7 +15,7 @@ */ #include "DevKernel.h" -#include "RoutingKernel.h" +#include "routing/RoutingKernel.h" #include "runner.h" #include "tensorrt_llm/common/config.h" #include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.h" @@ -56,7 +56,7 @@ inline int32_t computeLog2(int32_t val, std::string const& name = "") Runner::Runner() {} -Runner::Runner(int32_t tileTokensDim) +Runner::Runner(int32_t tileTokensDim, int32_t clusterSizeInBatchDim) : mTileTokensDim(tileTokensDim) { } @@ -67,15 +67,175 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt, bool useRoutingScalesOnInput, - bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) + bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream, btg::Dtype dtypeRoutingLogits) { - if (routingMethodType == RoutingMethodType::DeepSeekV3) + if (routingMethodType == RoutingMethodType::DeepSeekV3 && nGroup <= 1) + { + // DeepSeek no-groups case: use routingCustom with SigmoidBias preprocess + // and ScaledSumNormalize postprocess. This is more efficient than the full DeepSeek + // kernel because it uses the warp-level routingTopKExperts flow. + moe::dev::routing::routingCustom::Data routingData; + + // + // Config + // + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeRoutingLogits; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::SigmoidBias; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::ScaledSumNormalize; + routingData.mPtrRoutingBias = routingBias; + // Bias is always bfloat16 in the current Runner::run() API (no separate bias dtype param). + // The bias buffer dtype is determined by the caller (thop), not by the routing logits dtype. + routingData.mDtypeBias = btg::Dtype::Bfloat16; + routingData.mRouteScale = routedScalingFactor; + + // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method + routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr; + // + // Outputs + // + routingData.mPtrTopKPacked = routingExpertIndexes; + routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + routingData.mPtrTopKIds = expertIds; + // + // Grouped Gemm Launch Config Buffers + // + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // + // Inputs + // + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + + moe::dev::routing::routingCustom::run(routingData, stream); + } + else if (routingMethodType == RoutingMethodType::SigmoidRenorm) + { + // SigmoidRenorm: sigmoid(logit) → topK → renormalize. + // No bias, no scaling factor — pure sigmoid activation with top-K renormalization. + moe::dev::routing::routingCustom::Data routingData; + + // + // Config + // + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeRoutingLogits; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::Sigmoid; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::SumNormalize; + routingData.mNormTopkProb = true; + + // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method + routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr; + // + // Outputs + // + routingData.mPtrTopKPacked = routingExpertIndexes; + routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + routingData.mPtrTopKIds = expertIds; + // + // Grouped Gemm Launch Config Buffers + // + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // + // Inputs + // + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + + moe::dev::routing::routingCustom::run(routingData, stream); + } + else if (routingMethodType == RoutingMethodType::MiniMax2) + { + // MiniMaxM2: sigmoid(logit) + bias → topK → renormalize un-biased sigmoid scores. + // Similar to DeepSeek no-groups but with routeScale = 1.0 and epsilon = 1e-20 + // to match the Python reference: weight / (sum + 1e-20). + moe::dev::routing::routingCustom::Data routingData; + + // + // Config + // + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeRoutingLogits; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::SigmoidBias; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::ScaledSumNormalize; + routingData.mPtrRoutingBias = routingBias; + // Bias is always bfloat16 in the current Runner::run() API (no separate bias dtype param). + routingData.mDtypeBias = btg::Dtype::Bfloat16; + routingData.mRouteScale = 1.0f; + routingData.mSumEpsilon = 1e-20f; + + // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method + routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr; + // + // Outputs + // + routingData.mPtrTopKPacked = routingExpertIndexes; + routingData.mPtrExpertCounts = expertCountHistogram; + routingData.mPtrPermutedIdxSize = permutedIdxSize; + routingData.mPtrExpandedIdxToPermutedIdx = expandedIdxToPermutedIdx; + routingData.mPtrPermutedIdxToExpandedIdx = permutedIdxToExpandedIdx; + routingData.mPtrPermutedIdxToTokenIdx = permutedIdxToTokenIdx; + routingData.mPtrTopKWeights = expertWeights; + routingData.mPtrTopKIds = expertIds; + // + // Grouped Gemm Launch Config Buffers + // + routingData.mPtrCtaIdxXyToBatchIdx = ctaIdxXyToBatchIdx; + routingData.mPtrCtaIdxXyToMnLimit = ctaIdxXyToMnLimit; + routingData.mPtrNumNonExitingCtas = numNonExitingCtas; + + // + // Inputs + // + routingData.mNumTokens = numTokens; + routingData.mNumExperts = numExperts; + routingData.mTopK = topK; + routingData.mPaddingLog2 = computeLog2(mTileTokensDim); + routingData.mTileTokensDim = mTileTokensDim; + routingData.mLocalExpertsStartIdx = localExpertOffset; + routingData.mLocalExpertsStrideLog2 = 0; + routingData.mNumLocalExperts = localNumExperts; + + moe::dev::routing::routingCustom::run(routingData, stream); + } + else if (routingMethodType == RoutingMethodType::DeepSeekV3) { TLLM_CHECK_WITH_INFO(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); TLLM_CHECK_WITH_INFO(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); moe::dev::routing::routingDeepSeek::Data routingData; - routingData.mDtypeExpW = btg::Dtype::Bfloat16; - routingData.mUsePdl = true; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); // output: routingData.mPtrTopKPacked = routingExpertIndexes; @@ -92,6 +252,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 // input: routingData.mPtrRoutingBias = routingBias; + // Bias is always bfloat16 in the current Runner::run() API (no separate bias dtype param). + routingData.mDtypeBias = btg::Dtype::Bfloat16; // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr; routingData.mPtrTopKIds = expertIds; @@ -117,8 +279,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 TLLM_LOG_WARNING("For Llama routing method, nGroup/topkGroup is ignored, got %d/%d.", nGroup, topkGroup); } moe::dev::routing::routingLlama4::Data routingData; - routingData.mDtypeExpW = btg::Dtype::Bfloat16; - routingData.mUsePdl = true; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); // output: routingData.mPtrTopKPacked = routingExpertIndexes; @@ -157,20 +319,43 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 // routingData.mUseRoutingSoftmax = false; moe::dev::routing::routingLlama4::run(routingData, stream); } - else if (routingMethodType == RoutingMethodType::Renormalize /* default */ - || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */) + else if (routingMethodType == RoutingMethodType::Renormalize + || routingMethodType == RoutingMethodType::RenormalizeNaive || routingMethodType == RoutingMethodType::Default) { - moe::dev::routing::routingRenormalize::Data routingData; + moe::dev::routing::routingCustom::Data routingData; // // Config // - routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = dtypeRoutingLogits; // routingData.mDtypeElt = dtypeElt; // no-op for now as hidden_state is not input - routingData.mUsePdl = tensorrt_llm::common::getEnvEnableTrtllmgenMoeRoutingRenormPDL(); - routingData.mDoSoftmaxBeforeTopK = routingMethodType == RoutingMethodType::RenormalizeNaive; - routingData.mNormTopkProb = routingMethodType == RoutingMethodType::RenormalizeNaive; + routingData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); + if (routingMethodType == RoutingMethodType::Default) + { + // Default: Softmax -> TopK (no postprocessing) + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::Softmax; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::None; + } + else + { + // Renormalize and RenormalizeNaive are mathematically equivalent: + // RenormalizeNaive: softmax(all N experts) → topK → divide by sum of topK + // Renormalize: topK(raw scores) → softmax(K experts) + // + // Both produce identical output because: + // 1. softmax is monotonic, so topK selection yields the same experts + // 2. softmax(topK raw scores) = softmax(topK softmax scores) after renormalization, + // since softmax(x_i) / Σ softmax(x_j) = exp(x_i) / Σ exp(x_j) for the topK subset + // + // We always use the Renormalize path (NoOp preprocess + Softmax postprocess) + // because it only computes softmax over K experts instead of all N, which is faster + // — especially for large expert counts (e.g., 256 experts with topK=8). + routingData.mPreprocessType = moe::dev::routing::RoutingPreprocessType::None; + routingData.mPostprocessType = moe::dev::routing::RoutingPostprocessType::Softmax; + routingData.mNormTopkProb = true; + } // Pass-through raw pointer; kernels will cast to the proper InputT based on routing method routingData.mPtrScores = expertIds == nullptr ? routingLogits : nullptr; @@ -204,7 +389,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = localNumExperts; - moe::dev::routing::routingRenormalize::run(routingData, stream); + moe::dev::routing::routingCustom::run(routingData, stream); } else { @@ -291,12 +476,12 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* // tensorrt_llm/_torch/modules/fused_moe/quantization.py:MXFP4WeightTRTLLMGenFusedMoEMethod.input_hidden_alignment validHiddenSize = tensorrt_llm::common::roundUp(validHiddenSize, 512); } - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); bool is_gated_activation = mActType == ActType::SwiGlu; int32_t intermediateSizeFactor = (is_gated_activation ? 2 : 1); mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, numTokens, intermediateSizeFactor * validIntermediateSize, validHiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, + maxNumCgasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, useRoutingScalesOnInput ? expertWeights : nullptr, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, @@ -306,31 +491,31 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t configIndex) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); int32_t const intermediateSizeFactor = mActType == ActType::SwiGlu ? 2 : 1; return mRunner.getWorkspaceSizeInBytes(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, - numTokens, numExperts, maxNumCtasInBatchDim, configIndex); + numTokens, numExperts, maxNumCgasInBatchDim, configIndex); } int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); bool is_gated_activation = mActType == ActType::SwiGlu; return mRunner.getDefaultValidConfigIndex(numTokens, is_gated_activation ? 2 * intermediateSize : intermediateSize, - hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, numTokens, 2 * validIntermediateSize, + hiddenSize, {}, numTokens, numExperts, maxNumCgasInBatchDim, numTokens, 2 * validIntermediateSize, validHiddenSize); } bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); bool is_gated_activation = mActType == ActType::SwiGlu; auto const isValid = mRunner.isValidConfigIndex(configIndex, numTokens, is_gated_activation ? 2 * intermediateSize : intermediateSize, hiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, numTokens, 2 * validIntermediateSize, validHiddenSize); + maxNumCgasInBatchDim, numTokens, 2 * validIntermediateSize, validHiddenSize); return isValid; } @@ -391,9 +576,9 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void // The multiple is no less than 128 as TMA requires it for CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B types validIntermediateSize = tensorrt_llm::common::roundUp(validIntermediateSize, 128); } - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); mRunner.run(numTokens, hiddenSize, intermediateSize, numTokens, validHiddenSize, validIntermediateSize, {}, - numTokens, numExperts, maxNumCtasInBatchDim, permutedHiddenState, permutedHiddenStateScale, weights, + numTokens, numExperts, maxNumCgasInBatchDim, permutedHiddenState, permutedHiddenStateScale, weights, weightsScale, /* perTokensSfA */ nullptr, /* perTokensSfB */ nullptr, outputScalesScalar, /* outputScalesGateScalar */ nullptr, ptrBias, /* ptrAlpha */ nullptr, /* ptrBeta */ nullptr, /* clampLimit */ nullptr, output, outputScale, @@ -404,27 +589,27 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t configIndex) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); return mRunner.getWorkspaceSizeInBytes( - numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, configIndex); + numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCgasInBatchDim, configIndex); } int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const { - auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); return mRunner.getDefaultValidConfigIndex(numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, numTokens, validHiddenSize, validIntermediateSize); + maxNumCgasInBatchDim, numTokens, validHiddenSize, validIntermediateSize); } bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, int32_t intermediateSize, int32_t numExperts, int32_t numTokens, int32_t validHiddenSize, int32_t validIntermediateSize) const { - auto const maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + auto const maxNumCgasInBatchDim = Routing::getMaxNumCgasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); auto const isValid = mRunner.isValidConfigIndex(configIndex, numTokens, hiddenSize, intermediateSize, {}, numTokens, - numExperts, maxNumCtasInBatchDim, numTokens, validHiddenSize, validIntermediateSize); + numExperts, maxNumCgasInBatchDim, numTokens, validHiddenSize, validIntermediateSize); return isValid; } @@ -482,11 +667,11 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace convertSfData.numTokens = args.num_tokens; convertSfData.sfLayoutSrc = btg::SfLayout::R128c4; convertSfData.sfLayoutDst = btg::SfLayout::Linear; - convertSfData.mUsePdl = true; + convertSfData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); // Setup activation data activationData.mDtypeElt = args.mDtypeElt; - activationData.mUsePdl = true; + activationData.mUsePdl = tensorrt_llm::common::getEnvEnablePDL(); activationData.mUseDeepSeekFp8 = true; activationData.inPtr = workspace.gemm1_output; activationData.outPtr = workspace.activation_output; @@ -504,7 +689,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace // Setup finalize data finalizeData.mDtypeElt = args.mDtypeOut; finalizeData.mDtypeExpW = args.mDtypeExpW; - finalizeData.mUsePdl = true; + finalizeData.mUsePdl = false; finalizeData.mUseDeepSeekFp8 = false; finalizeData.inPtr = workspace.gemm2_output; finalizeData.outPtr = args.output; diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h index e922764061d..ba20fca4daf 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.h @@ -17,7 +17,7 @@ #pragma once #include "DevKernel.h" -#include "RoutingKernel.h" +#include "routing/RoutingKernel.h" #include "tensorrt_llm/common/config.h" #include "tensorrt_llm/common/cudaDriverWrapper.h" #include "tensorrt_llm/common/cudaUtils.h" @@ -76,12 +76,17 @@ enum class RoutingMethodType : int64_t DeepSeekV3 = 2, // Llama4: Top1 -> Sigmoid Llama4 = 3, - // RenormalizeNaive: Softmax -> TopK -> Renormalize + // RenormalizeNaive: Softmax -> TopK -> Renormalize. + // Mathematically equivalent to Renormalize (TopK -> Softmax), but conceptually applies + // softmax over all N experts first. At runtime, we use the Renormalize kernel path + // (TopK -> Softmax over K) which is faster since softmax is only over K selected experts. RenormalizeNaive = 4, // MiniMaxM2: Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize(without bias) MiniMax2 = 5, + // SigmoidRenorm: Sigmoid -> TopK -> Renormalize + SigmoidRenorm = 6, // Unspecified - Unspecified = 6, + Unspecified = 7, }; inline int32_t maybeGetMinTokenCount(int32_t numPaddedTokens, int32_t hiddenSize, int32_t dtypeSizeBits) @@ -101,45 +106,52 @@ inline std::string serializeMoeRoutingMethodType(RoutingMethodType routingMethod case RoutingMethodType::Llama4: return "Llama4"; case RoutingMethodType::RenormalizeNaive: return "RenormalizeNaive"; case RoutingMethodType::MiniMax2: return "MiniMax2"; + case RoutingMethodType::SigmoidRenorm: return "SigmoidRenorm"; default: TLLM_CHECK_WITH_INFO(false, "Invalid routing method"); return ""; }; } -inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t tileTokensDim) +inline int32_t getMaxNumCgasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t cgaTileTokensDim) { - // For MoE, mNumTokens != 0 and the number of CTAs is known only at runtime. - // We launch maximally possible number of CTAs and use ptrNumNonExitingCtas to determine - // the actual number of CTAs to run. + // For MoE, mNumTokens != 0 and the number of CGAs is known only at runtime. + // We launch maximally possible number of CGAs and use ptrNumNonExitingCtas to determine + // the actual number of CGAs to run. // Initialize number of tokens with the number of expanded tokens after routing. - int32_t numRemainingTokens = numTokens * topK; - int32_t maxNumCtasInBatchDim = 0; - // First, distribute one token each expert until token depletion to maximize CTA tile count. - int32_t numExpertsFilled = std::min(numExperts, numRemainingTokens); - maxNumCtasInBatchDim += numExpertsFilled; + auto numRemainingTokens = numTokens * topK; + int32_t maxNumCgasInBatchDim = 0; + // First, distribute one token each expert until token depletion to maximize CGA tile count. + auto numExpertsFilled = std::min(numExperts, numRemainingTokens); + maxNumCgasInBatchDim += numExpertsFilled; numRemainingTokens -= numExpertsFilled; - // Next, greedily pour all remaining tokens to one expert to maximize CTA tile count. + // Next, greedily pour all remaining tokens to one expert to maximize CGA tile count. // E.g., at this point tokens over 4 experts are [1, 1, 1, 1], and we have 4 tokens left. - // If each CTA handles 4 tokens/expert, the greedy strategy is to pour all remaining tokens - // to any one expert to get to the 5th CTA tile. Otherwise, we can only get 4 tiles in total. + // If each CGA handles 4 tokens/expert, the greedy strategy is to pour all remaining tokens + // to any one expert to get to the 5th CGA tile. Otherwise, we can only get 4 tiles in total. // // Another way to reason about this is to pour the remaining tokens into buckets of some fixed // capacity. These buckets, if full, can then be attributed to any expert; it does not have to // belong to the same expert every time. if (numRemainingTokens > 0) { - // For every tileTokenDim tokens, we add an extra CTA tile in the token dimension. - // The number of CTA tiles is given by divDown(numRemainingTokens, tokenTileDim). - maxNumCtasInBatchDim += (numRemainingTokens / tileTokensDim); + // For every tileTokenDim tokens, we add an extra CGA tile in the token dimension. + // The number of CGA tiles is given by divDown(numRemainingTokens, tokenTileDim). + maxNumCgasInBatchDim += (numRemainingTokens / cgaTileTokensDim); } - return maxNumCtasInBatchDim; + return maxNumCgasInBatchDim; +} + +// Backward-compatible alias — callers outside routing may still use the old name. +inline int32_t getMaxNumCtasInBatchDim(int32_t numTokens, int32_t topK, int32_t numExperts, int32_t tileTokensDim) +{ + return getMaxNumCgasInBatchDim(numTokens, topK, numExperts, tileTokensDim); } inline int32_t getMaxPermutedPaddedCount( int32_t numTokens, int32_t expertsPerToken, int32_t numExperts, int32_t padding) { - int32_t maxCtas = getMaxNumCtasInBatchDim(numTokens, expertsPerToken, numExperts, padding); - return maxCtas * padding; + int32_t maxCgas = getMaxNumCgasInBatchDim(numTokens, expertsPerToken, numExperts, padding); + return maxCgas * padding; } class Runner @@ -147,7 +159,7 @@ class Runner public: explicit Runner(); - explicit Runner(int32_t tileTokensDim); + explicit Runner(int32_t tileTokensDim, int32_t clusterSizeInBatchDim = 1); void run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts, int32_t topK, int32_t nGroups, int32_t topkGroups, int32_t localExpertOffset, int32_t localNumExperts, @@ -156,7 +168,8 @@ class Runner int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* expertIds, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, batchedGemm::trtllm::gen::Dtype dtypeElt, bool useRoutingScalesOnInput, bool useDeepSeekFp8, - RoutingMethodType routingMethodType, cudaStream_t stream); + RoutingMethodType routingMethodType, cudaStream_t stream, + batchedGemm::trtllm::gen::Dtype dtypeRoutingLogits = batchedGemm::trtllm::gen::Dtype::Bfloat16); private: int32_t mTileTokensDim; diff --git a/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp b/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp index 5f17e2372b6..cb6765ac6a5 100644 --- a/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp +++ b/cpp/tensorrt_llm/thop/cuteDslMoeUtilsOp.cpp @@ -20,6 +20,8 @@ #include +namespace btg = batchedGemm::trtllm::gen; + TRTLLM_NAMESPACE_BEGIN namespace torch_ext @@ -74,6 +76,9 @@ std::vector moe_topk_sort_impl(torch::optional con tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits->get_device() : token_selected_experts->get_device()); + auto const dtypeRoutingLogits = routing_logits.has_value() + ? (routing_logits->scalar_type() == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; routing_runner.run(routing_logits_ptr, routing_bias_ptr, num_tokens, num_experts, top_k, n_group.value_or(0), topk_group.value_or(0), local_expert_offset, local_num_experts, routed_scaling_factor.value_or(1.0), expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -82,7 +87,7 @@ std::vector moe_topk_sort_impl(torch::optional con num_tokens_per_expert.data_ptr(), tile_idx_to_expert_idx.data_ptr(), tile_idx_to_mn_limit.data_ptr(), num_non_exiting_tiles.data_ptr(), batchedGemm::trtllm::gen::Dtype::Void /* dtypeElt */, false /* use_routing_scales_on_input */, - false /* use_deep_seek_fp8 */, routing_method_type, stream); + false /* use_deep_seek_fp8 */, routing_method_type, stream, dtypeRoutingLogits); std::vector results{tile_idx_to_expert_idx, tile_idx_to_mn_limit, expanded_idx_to_permuted_idx, permuted_idx_to_expanded_idx, total_num_padded_tokens, num_non_exiting_tiles}; diff --git a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp index 3440a59737f..4e270ec109a 100644 --- a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp @@ -68,15 +68,9 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional(routing_method_type) == RoutingMethodType::DeepSeekV3) - { - TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); - } - else - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16"); - } + TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16 + || routing_logits.value().scalar_type() == at::ScalarType::Float, + "routing_logits must be bfloat16 or float32"); TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits has incorrect shape."); } @@ -264,6 +258,9 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -272,7 +269,7 @@ std::vector run_fp4_block_scale_moe_runner(torch::optional(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */, - static_cast(routing_method_type), stream); + static_cast(routing_method_type), stream, dtypeRoutingLogits); // // FC13 (gemm1) + FC2 (gemm2) diff --git a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp index 3c13c695991..a0b857b24f6 100644 --- a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp @@ -63,15 +63,9 @@ at::Tensor run_fp8_block_scale_moe(at::optional const& routing_logit } else if (routing_logits.has_value()) { - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) - { - TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); - } - else - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16"); - } + TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16 + || routing_logits.value().scalar_type() == at::ScalarType::Float, + "routing_logits must be bfloat16 or float32"); TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits dim1 must match num_experts."); } @@ -232,6 +226,9 @@ at::Tensor run_fp8_block_scale_moe(at::optional const& routing_logit tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits.value().get_device() : topk_ids.value().get_device()); + auto const dtypeRoutingLogits = routing_logits.has_value() + ? (routing_logits.value().scalar_type() == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -239,7 +236,7 @@ at::Tensor run_fp8_block_scale_moe(at::optional const& routing_logit permuted_idx_to_token_idx.data_ptr(), expert_weights_ptr, args.topk_ids, num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, false, true, - static_cast(routing_method_type), stream); + static_cast(routing_method_type), stream, dtypeRoutingLogits); // MoE kernel except routing TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8."); diff --git a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp index 092f8f01362..183a9172169 100644 --- a/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp8PerTensorScaleMoe.cpp @@ -57,24 +57,9 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional con } else if (routing_logits.has_value()) { - if (use_routing_scales_on_input) - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16."); - } - else - { - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); - } - else - { - TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16, - "routing_logits must be bfloat16"); - } - } + TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16 + || routing_logits.value().scalar_type() == at::ScalarType::Float, + "routing_logits must be bfloat16 or float32"); TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits has incorrect shape."); } @@ -230,6 +215,9 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional con tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits.value().get_device() : topk_ids.value().get_device()); + auto const dtypeRoutingLogits = routing_logits.has_value() + ? (routing_logits.value().scalar_type() == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -238,7 +226,7 @@ torch::Tensor fp8_per_tensor_scale_moe_runner(torch::optional con num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, use_routing_scales_on_input, false /* use_deep_seek_fp8 */, static_cast(routing_method_type), - stream); + stream, dtypeRoutingLogits); // MoE kernel except routing TORCH_CHECK(hidden_states.scalar_type() == at::ScalarType::Float8_e4m3fn, "hidden_states must be fp8."); diff --git a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp index 5e8331b77c3..41788fc3a84 100644 --- a/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/mxFp4BlockScaleMoe.cpp @@ -72,15 +72,9 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional } else if (routing_logits.has_value()) { - if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) - { - TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::Float, "routing_logits must be float"); - } - else - { - TORCH_CHECK( - routing_logits.value().scalar_type() == at::ScalarType::BFloat16, "routing_logits must be bfloat16"); - } + TORCH_CHECK(routing_logits.value().scalar_type() == at::ScalarType::BFloat16 + || routing_logits.value().scalar_type() == at::ScalarType::Float, + "routing_logits must be bfloat16 or float32"); TORCH_CHECK(routing_logits.value().dim() == 2, "routing_logits must be 2D."); TORCH_CHECK(routing_logits.value().sizes()[1] == num_experts, "routing_logits dim1 must match num_experts."); @@ -274,6 +268,9 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::Runner routing_runner(tile_tokens_dim); auto const& stream = at::cuda::getCurrentCUDAStream( routing_logits.has_value() ? routing_logits.value().get_device() : topk_ids.value().get_device()); + auto const dtypeRoutingLogits = routing_logits.has_value() + ? (routing_logits.value().scalar_type() == at::ScalarType::Float ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16) + : btg::Dtype::Bfloat16; routing_runner.run(args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, expert_indexes.data_ptr(), expert_count_histogram.data_ptr(), total_num_padded_tokens.data_ptr(), @@ -282,7 +279,7 @@ torch::Tensor dtype_mxe2m1_block_scale_moe_runner(torch::optional num_tokens_per_expert.data_ptr(), cta_idx_xy_to_batch_idx.data_ptr(), cta_idx_xy_to_mn_limit.data_ptr(), num_non_exiting_ctas.data_ptr(), args.mDtypeElt, false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */, - static_cast(routing_method_type), stream); + static_cast(routing_method_type), stream, dtypeRoutingLogits); // // FC13 (gemm1) + FC2 (gemm2) diff --git a/cpp/tests/unit_tests/kernels/CMakeLists.txt b/cpp/tests/unit_tests/kernels/CMakeLists.txt index ab8280498e5..95d33e42105 100644 --- a/cpp/tests/unit_tests/kernels/CMakeLists.txt +++ b/cpp/tests/unit_tests/kernels/CMakeLists.txt @@ -89,7 +89,7 @@ add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}") set(ROUTING_KERNEL_TEST_SRC routing/routingTest.cpp routing/routingLlama4Test.cpp - routing/routingRenormalizeTest.cpp routing/routingDeepSeekTest.cpp) + routing/routingCustomTest.cpp routing/routingDeepSeekTest.cpp) add_gtest(routingKernelsTest "${ROUTING_KERNEL_TEST_SRC}") target_link_libraries(routingKernelsTest PRIVATE Python3::Python) diff --git a/cpp/tests/unit_tests/kernels/routing/routingCustomTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingCustomTest.cpp new file mode 100644 index 00000000000..4aa7e0762d8 --- /dev/null +++ b/cpp/tests/unit_tests/kernels/routing/routingCustomTest.cpp @@ -0,0 +1,1549 @@ +/* + * Copyright (c) 2022-2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tests/unit_tests/kernels/routing/routingTest.h" + +#include + +namespace tk = tensorrt_llm::kernels; +namespace btg = batchedGemm::trtllm::gen; +using namespace tensorrt_llm::runtime; +using namespace tensorrt_llm::tests::kernels::routing; + +namespace +{ + +template +class RoutingCustomKernelTest : public RoutingKernelTest +{ + +protected: + using RoutingKernelTest::mSeed; + using RoutingKernelTest::mStream; + using RoutingKernelTest::mBufferManager; + using typename RoutingKernelTest::PackedType; + +private: + // Routing bias buffers (used by SigmoidBias preprocess) + TensorPtr mPtrRoutingBiasHost; + TensorPtr mPtrRoutingBiasDevice; + + static float sigmoid_accurate(float x) + { + return 0.5f * std::tanh(0.5f * x) + 0.5f; + } + + // Reference implementation for all policy combinations: + // 1. Softmax + NoOp (Default: softmax before topK, raw scores) + // 2. NoOp + Softmax (Renormalize: topK first, then softmax) + // 3. Softmax + SumNormalize (RenormalizeNaive path) + // 4. SigmoidBias + ScaledSumNormalize (DeepSeek-style path) + // 5. Sigmoid + SumNormalize (SigmoidRenorm path) + // 6. NoOp + NoOp (raw topK, no transformation) + void computeTopKExperts(RoutingKernelTestParam const& param) override + { + for (int it = 0; it < param.numTokens; ++it) + { + std::vector expWeightsIdx(param.numExperts); + std::vector expIdx(param.topK); + + // Per-expert sigmoid scores — only populated for SigmoidBias preprocess. + std::vector sigmoidScores(param.numExperts, 0.f); + + // --- Read raw scores and apply preprocess --- + for (int ie = 0; ie < param.numExperts; ++ie) + { + float score = static_cast(bufferCast(*this->mPtrScoresHost)[it * param.numExperts + ie]); + + if (param.preprocessType == RoutingPreprocessType::Sigmoid) + { + float sig = sigmoid_accurate(score); + score = ie < param.numExperts ? sig : -std::numeric_limits::infinity(); + } + else if (param.preprocessType == RoutingPreprocessType::SigmoidBias) + { + float sig = sigmoid_accurate(score); + sigmoidScores[ie] = sig; + float bias = static_cast(bufferCast(*mPtrRoutingBiasHost)[ie]); + score = sig + bias; + } + + expWeightsIdx[ie] = PackedFloat{score, static_cast(ie)}; + } + + // Apply softmax preprocess (over all experts) when requested + if (param.preprocessType == RoutingPreprocessType::Softmax) + { + float maxScore = -std::numeric_limits::infinity(); + for (int ie = 0; ie < param.numExperts; ++ie) + { + maxScore = std::max(maxScore, expWeightsIdx[ie].score); + } + float sum = 0.f; + for (int ie = 0; ie < param.numExperts; ++ie) + { + expWeightsIdx[ie].score = std::exp(expWeightsIdx[ie].score - maxScore); + sum += expWeightsIdx[ie].score; + } + for (int ie = 0; ie < param.numExperts; ++ie) + { + expWeightsIdx[ie].score /= sum; + } + } + + // --- TopK selection --- + std::partial_sort_copy(expWeightsIdx.begin(), expWeightsIdx.end(), expIdx.begin(), expIdx.end(), comp); + + // --- Apply postprocess --- + if (param.postprocessType == RoutingPostprocessType::Softmax) + { + // Softmax over top-K scores + float maxScore = -std::numeric_limits::infinity(); + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + maxScore = std::max(maxScore, expIdx[ie].score); + } + float sum = 0.f; + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + sum += std::exp(expIdx[ie].score - maxScore); + } + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + expIdx[ie].score = std::exp(expIdx[ie].score - maxScore) / sum; + } + } + else if (param.postprocessType == RoutingPostprocessType::SumNormalize) + { + // SumNormalize: divide top-K scores by their sum + if (param.normTopkProb) + { + float sum = 0.f; + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + sum += expIdx[ie].score; + } + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + expIdx[ie].score /= sum; + } + } + } + else if (param.postprocessType == RoutingPostprocessType::ScaledSumNormalize) + { + // Recover sigmoid scores, renormalize by their sum, and scale + float sumSigmoid = 0.f; + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + sumSigmoid += sigmoidScores[expIdx[ie].idx]; + } + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + expIdx[ie].score = sigmoidScores[expIdx[ie].idx] * param.routedScalingFactor / sumSigmoid; + } + } + // For NoOp postprocess: scores are left unchanged. + + // --- Store results --- + for (uint32_t ie = 0; ie < param.topK; ++ie) + { + // Set invalid topk indices for the first half of the topk + if (param.hasInvalidTopKInput && ie < param.topK / 2 + 1) + { + expIdx[ie].idx = static_cast(param.invalidExpertIdValue); + } + + PackedType si{static_cast(expIdx[ie].score), expIdx[ie].idx}; + reinterpret_cast(bufferCast(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si; + if (param.useTopKAsInput) + { + bufferCast(*this->mPtrTopKIdsHost)[it * param.topK + ie] + = static_cast(expIdx[ie].idx); + bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); + } + else if (param.getExpWeights) + { + bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); + } + } + } + } + +protected: + void allocateBuffers(RoutingKernelTestParam const& param) override + { + RoutingKernelTest::allocateBuffers(param); + int64_t scoresSize = param.numTokens * param.numExperts; + this->mPtrScoresHost = mBufferManager->pinned(ITensor::makeShape({scoresSize}), TRTDataType::value); + this->mPtrScoresDevice = mBufferManager->gpu(ITensor::makeShape({scoresSize}), TRTDataType::value); + + // Allocate routing bias buffers when needed + if (param.preprocessType == RoutingPreprocessType::SigmoidBias) + { + mPtrRoutingBiasHost = mBufferManager->pinned(ITensor::makeShape({param.numExperts}), TRTDataType::value); + mPtrRoutingBiasDevice = mBufferManager->gpu(ITensor::makeShape({param.numExperts}), TRTDataType::value); + } + } + + void setupBuffers(RoutingKernelTestParam const& param) override + { + RoutingKernelTest::setupBuffers(param); + + // Initialize routing bias with small random values + if (param.preprocessType == RoutingPreprocessType::SigmoidBias) + { + T* biasPtr = bufferCast(*mPtrRoutingBiasHost); + initData(biasPtr, param.numExperts, mSeed + 7); + mBufferManager->copy(*mPtrRoutingBiasHost, *mPtrRoutingBiasDevice); + } + } + + template + void setParams(RoutingKernelTestParam const& param, RoutingData& routingData) + { + RoutingKernelTest::setCommonParams(param, routingData); + + if (sizeof(T) == 4) + { + routingData.mDtypeOutput = btg::Dtype::Fp32; + routingData.mDtypeInput = btg::Dtype::Fp32; + } + else + { + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mDtypeInput = btg::Dtype::Bfloat16; + } + + // Set policy types from test param (already derived by build()) + routingData.mPreprocessType = param.preprocessType; + routingData.mPostprocessType = param.postprocessType; + routingData.mNormTopkProb = param.normTopkProb; + + // Set routing bias and scale when using SigmoidBias preprocess + if (param.preprocessType == RoutingPreprocessType::SigmoidBias) + { + routingData.mPtrRoutingBias = bufferCast(*mPtrRoutingBiasDevice); + // Bias dtype matches T (the test's type parameter) + routingData.mDtypeBias = (sizeof(T) == 4) ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16; + routingData.mRouteScale = param.routedScalingFactor; + } + + if (param.useTopKAsInput) + { + routingData.mPtrTopKIds = bufferCast(*this->mPtrTopKIdsDevice); + routingData.mPtrScores = nullptr; + } + else if (param.useTopKPackedAsInput) + { + // mPtrTopKPacked is already set by setCommonParams; just clear scores and topKIds + routingData.mPtrTopKIds = nullptr; + routingData.mPtrScores = nullptr; + } + else + { + routingData.mPtrTopKIds = nullptr; + routingData.mPtrScores = bufferCast(*this->mPtrScoresDevice); + } + } + + void callTestedFunction( + RoutingKernelTestParam const& param, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override + { + moe::dev::routing::routingCustom::Data routingData; + setParams(param, routingData); + moe::dev::routing::routingCustom::run(routingData, mStream->get()); + } +}; + +TYPED_TEST_SUITE(RoutingCustomKernelTest, FloatAndBf16Types); + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelParallelizationWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelParallelizationWithInvalidTopKInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .build(); + this->runTest(param); +}; + +// --- Tests for useTopKAsInput (mPtrTopKIds + mPtrTopKWeights as input) --- +// These test the runPostTopKPipeline path at block, cluster, and coop levels. + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelTopKAsInput) +{ + // Small token count -> single-block path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelTopKAsInput) +{ + // Medium token count -> single-cluster path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, CoopLevelTopKAsInput) +{ + // Large token count -> coop path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(192) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithInvalidTopKInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithRenormalizeNaive) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +// --- Tests for useTopKPackedAsInput (mPtrTopKPacked without mPtrScores) --- +// These test the runPostTopKPipeline path for the packed input format. + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelTopKPackedAsInput) +{ + // Small token count -> single-block path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKPackedAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelTopKPackedAsInput) +{ + // Medium token count -> single-cluster path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKPackedAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelTopKPackedAsInput) +{ + // Large token count -> coop or multi-kernel path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(256) + .withUseTopKPackedAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(10) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(200) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(200) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(256) + .withNumExperts(128) + .withTopK(4) + .withExpertParallelization(2, 1) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationWithRenormalizeNaiveTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +// --- Tests for Default (Softmax + NoOp postprocess) --- + +TYPED_TEST(RoutingCustomKernelTest, DefaultBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DefaultClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DefaultDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DefaultWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +// --- Tests for RenormalizeNaive at block and device levels --- + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeNaiveBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeNaiveDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeNaiveWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelParallelizationTop4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(4) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelParallelizationLargeN) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(2048) + .withTopK(32) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelParallelizationLargeN) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(2048) + .withTopK(32) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelParallelizationLargeN) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(2048) + .withTopK(32) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(2048) + .withTopK(32) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for invalid expertId = numExperts (instead of -1). +// Some frameworks use expertId == numExperts to mark unassigned slots. +// The kernel must handle this without illegal memory access. +// These tests exercise the block, cluster, and device-level paths with topKIds input +// where some expert IDs are set to numExperts. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, BlockLevelInvalidExpertIdEqualsNumExperts) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(128) // numExperts as invalid marker + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelInvalidExpertIdEqualsNumExperts) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(128) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeviceLevelInvalidExpertIdEqualsNumExperts) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(128) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +// Test with numExperts < MaxNumExperts (fall-through tier) — expertId=numExperts +// passes the `< MaxNumExperts` check but should still be treated as invalid. +TYPED_TEST(RoutingCustomKernelTest, BlockLevelInvalidExpertIdFallThroughTier) +{ + // numExperts=100 → dispatches to E128 tier (MaxNumExperts=128). + // expertId=100 passes `100 < 128` but is invalid (only 0..99 are valid). + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(100) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(100) // numExperts as invalid marker + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ClusterLevelInvalidExpertIdFallThroughTier) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(100) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withInvalidExpertIdValue(100) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for Renormalize with new expert/topK tiers (E160, E576, K22) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// --- E576 experts, K22 topK (exercises the new E576 and K22 tiers) --- + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE576K22) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeClusterLevelE576K22) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeDeviceLevelE576K22) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeDeviceLevelE576K22TopKAsInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +// --- E160 experts, K8 topK (exercises the new E160 tier) --- + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE160) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(4) + .withNumExperts(160) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeClusterLevelE160) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(160) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeDeviceLevelE160) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(160) + .withTopK(8) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeClusterLevelE160WithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(160) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for 384-expert tier (Renormalize + SigmoidBias policies). +// 384 is in getMaxNumExperts() tiers but was previously missing from some PolicyTraits, +// causing thread-count mismatch bugs. These tests cover block, cluster, and device paths. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// --- Renormalize with E384 (exercises Tier<384,8> in None+Softmax policy) --- + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeClusterLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(100) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeDeviceLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE384WithEP) +{ + // Mirrors the failing multi-GPU test: e384, topK=8, seq=1, EP=4 + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1) + .withNumExperts(384) + .withTopK(8) + .withExpertParallelization(4, 1) + .withTileTokensDim(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE384TopKAsInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, RenormalizeBlockLevelE384TopKAsInputWithEP) +{ + // Mirrors the failing multi-GPU test with pre-computed topK: e384, topK=8, seq=1, EP=4 + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1) + .withNumExperts(384) + .withTopK(8) + .withExpertParallelization(4, 1) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .build(); + this->runTest(param); +}; + +// --- SigmoidBias with E384 (exercises Tier<384,8> in SigmoidBias+ScaledSumNormalize policy) --- + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasBlockLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(4) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasClusterLevelE384) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(100) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for scores input + cooperative kernel path (scores→topK kernel + coop histogram+offsets). +// These verify the coop fast-path when input is raw mPtrScores (not pre-computed topK). +// Triggered when numTokens > cluster capacity (256) and within coop capacity. +// Requires SM90+ (coop kernel uses grid-sync). +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalize) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeE160) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(160) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeE256K4) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(256) + .withTopK(4) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeE576K22) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(500) + .withNumExperts(576) + .withTopK(22) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopDefault) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Default) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopRenormalizeNaive) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, ScoresCoopSigmoidBias) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for SigmoidBias + ScaledSumNormalize (DeepSeek-style routing via routingCustom) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SigmoidBias PolicyTraits: only E512 × K8. + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(4) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(100) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidBiasWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::SigmoidBias) + .withPostprocessType(RoutingPostprocessType::ScaledSumNormalize) + .withRoutedScalingFactor(2.5f) + .withNumTokens(10) + .withNumExperts(512) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for MiniMax2 (SigmoidBias + ScaledSumNormalize with routeScale=1.0) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MiniMax2 PolicyTraits: same as SigmoidBias, only E512 × K8. + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2BlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(4) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2ClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(100) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2DeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2WithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(10) + .withNumExperts(512) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for mixed input/bias dtypes (SigmoidBias with float32 scores + bfloat16 bias, and vice versa). +// These test the loadScalar + mDtypeBias dispatch for cross-dtype bias reading. +// The test allocates bias in the "opposite" dtype from T. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, MiniMax2MixedBiasDtype) +{ + using OtherT = std::conditional_t, __nv_bfloat16, float>; + + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(100) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .build(); + + // Allocate and setup normal buffers + this->allocateBuffers(param); + this->setupBuffers(param); + + // Allocate bias in the "opposite" dtype from T + auto otherBiasHost + = this->mBufferManager->pinned(ITensor::makeShape({param.numExperts}), TRTDataType::value); + auto otherBiasDevice + = this->mBufferManager->gpu(ITensor::makeShape({param.numExperts}), TRTDataType::value); + auto biasPtr = bufferCast(*otherBiasHost); + for (int i = 0; i < param.numExperts; i++) + { + biasPtr[i] = static_cast(0.01f * (i % 100)); + } + this->mBufferManager->copy(*otherBiasHost, *otherBiasDevice); + this->mStream->synchronize(); + + // Setup routing data with mixed dtypes + moe::dev::routing::routingCustom::Data routingData; + this->setCommonParams(param, routingData); + routingData.mDtypeOutput = (sizeof(TypeParam) == 4) ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16; + routingData.mDtypeInput = routingData.mDtypeOutput; + routingData.mPreprocessType = param.preprocessType; + routingData.mPostprocessType = param.postprocessType; + routingData.mNormTopkProb = param.normTopkProb; + routingData.mPtrScores = bufferCast(*this->mPtrScoresDevice); + routingData.mPtrRoutingBias = bufferCast(*otherBiasDevice); + // Bias dtype is intentionally different from scores dtype (T) to test mixed-precision support. + // e.g. T=float → OtherT=bfloat16, T=bfloat16 → OtherT=float. + routingData.mDtypeBias = (sizeof(OtherT) == 4) ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16; + routingData.mRouteScale = param.routedScalingFactor; + + // Run kernel — verifies it doesn't crash with mixed bias dtype + moe::dev::routing::routingCustom::run(routingData, this->mStream->get()); + this->mStream->synchronize(); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for DeepSeek nGroup=1 via routingCustom (SigmoidBias + ScaledSumNormalize with routeScale != 1.0) +// When nGroup <= 1, DeepSeek routing is equivalent to SigmoidBias + ScaledSumNormalize, +// and production code routes through routingCustom (not routingDeepSeek). +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// DeepSeek nGroup=1: uses SigmoidBias policy (E512 × K8). + +TYPED_TEST(RoutingCustomKernelTest, DeepSeekNoGroupBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(4) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withRoutedScalingFactor(2.5f) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeepSeekNoGroupClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(100) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withRoutedScalingFactor(2.5f) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DeepSeekNoGroupDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::MiniMax2) + .withNumTokens(1000) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withRoutedScalingFactor(2.5f) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for SigmoidRenorm (Sigmoid + SumNormalize) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, SigmoidRenormBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::SigmoidRenorm) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidRenormClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::SigmoidRenorm) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +// SigmoidRenorm PolicyTraits: only E128 × K8. + +TYPED_TEST(RoutingCustomKernelTest, SigmoidRenormDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::SigmoidRenorm) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, SigmoidRenormWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::SigmoidRenorm) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Tests for NoOp + NoOp (raw topK, no score transformation) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, NoOpBlockLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::None) + .withPostprocessType(RoutingPostprocessType::None) + .withNumTokens(4) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, NoOpClusterLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::None) + .withPostprocessType(RoutingPostprocessType::None) + .withNumTokens(100) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +// NoOp PolicyTraits: only E128 × K8. + +TYPED_TEST(RoutingCustomKernelTest, NoOpDeviceLevel) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::None) + .withPostprocessType(RoutingPostprocessType::None) + .withNumTokens(1000) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, NoOpWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withPreprocessType(RoutingPreprocessType::None) + .withPostprocessType(RoutingPostprocessType::None) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(256) + .build(); + this->runTest(param); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Dynamic-block kernel tests (5-16 tokens, ≤512 experts) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingCustomKernelTest, DynBlockBasic) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(8) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockMaxTokens) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(16) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockWithExpertParallelization) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(12) + .withNumExperts(512) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .withUsePdl(true) + .withGetExpWeights(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockWithTopKAsInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(8) + .withNumExperts(128) + .withTopK(8) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockWithInvalidTopKInput) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Renormalize) + .withNumTokens(10) + .withNumExperts(512) + .withTopK(8) + .withExpertParallelization(2, 0) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingCustomKernelTest, DynBlockWithRenormalizeNaive) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::RenormalizeNaive) + .withNumTokens(16) + .withNumExperts(512) + .withTopK(8) + .withTileTokensDim(256) + .withUsePdl(true) + .withGetExpWeights(true) + .withRequiredComputeCapability(9) + .build(); + this->runTest(param); +}; + +} // namespace diff --git a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp index 0467c174965..e29ed24adbf 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp @@ -151,6 +151,7 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest } } +protected: void allocateBuffers(RoutingKernelTestParam const& param) { RoutingKernelTest::allocateBuffers(param); @@ -179,9 +180,11 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest void setParams(RoutingKernelTestParam const& param, RoutingData& routingData) { RoutingKernelTest::setCommonParams(param, routingData); - routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; routingData.mPtrRoutingBias = bufferCast(*this->mPtrRoutingBiasDevice); + // Bias dtype matches T (the test's type parameter) + routingData.mDtypeBias = (sizeof(T) == 4) ? btg::Dtype::Fp32 : btg::Dtype::Bfloat16; routingData.mNumExpertGroups = param.nGroup; routingData.mNumLimitedGroups = param.topkGroup; @@ -193,6 +196,12 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest routingData.mPtrTopKIds = bufferCast(*this->mPtrTopKIdsDevice); routingData.mPtrScores = nullptr; } + else if (param.useTopKPackedAsInput) + { + // mPtrTopKPacked is already set by setCommonParams; just clear scores and topKIds + routingData.mPtrTopKIds = nullptr; + routingData.mPtrScores = nullptr; + } else { routingData.mPtrTopKIds = nullptr; @@ -213,199 +222,377 @@ TYPED_TEST_SUITE(RoutingDeepSeekKernelTest, Bf16Types); TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization32) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 - /*numExperts=*/32, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(32) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization72) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 - /*numExperts=*/72, /*topK=*/6, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(72) + .withTopK(6) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization384) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 - /*numExperts=*/384, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization512) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/4, // 1024 - /*numExperts=*/512, /*topK=*/22, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(512) + .withTopK(22) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10 - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1024) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingDeepSeekKernelTest, BlockLevelTopKAsInput) +{ + // Small token count -> single-block path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withUseTopKAsInput(true) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10 - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1024) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithTopKAsInput384) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1024, // 10 - /*numExperts=*/384, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1024) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withNGroup(1) + .withTopkGroup(1) + .build(); + this->runTest(param); +}; + +// --- Tests for useTopKPackedAsInput (mPtrTopKPacked without mPtrScores) --- +// These test the runPostTopKPipeline path for the packed input format. + +TYPED_TEST(RoutingDeepSeekKernelTest, BlockLevelTopKPackedAsInput) +{ + // Small token count -> single-block path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(4) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withUseTopKPackedAsInput(true) + .withNGroup(8) + .withTopkGroup(4) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelTopKPackedAsInput) +{ + // Medium token count -> single-cluster path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(100) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withUseTopKPackedAsInput(true) + .withNGroup(8) + .withTopkGroup(4) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelTopKPackedAsInput) +{ + // Large token count -> coop or multi-kernel path in runPostTopKPipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(2048) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKPackedAsInput(true) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParallelization) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/100, - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(100) + .withNumExperts(256) + .withTopK(8) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Test DeepSeek main kernel with float32 bias (T=bf16 for scores output, but bias is float32). +// This exercises the loadScalar path with mismatched bias dtype. +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelWithFloat32Bias) +{ + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(100) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(192) + .withNGroup(8) + .withTopkGroup(4) + .build(); + + this->allocateBuffers(param); + + // Override: allocate bias as float32 instead of T (bf16) + auto float32BiasHost + = this->mBufferManager->pinned(ITensor::makeShape({param.numExperts}), nvinfer1::DataType::kFLOAT); + auto float32BiasDevice + = this->mBufferManager->gpu(ITensor::makeShape({param.numExperts}), nvinfer1::DataType::kFLOAT); + auto biasPtr = bufferCast(*float32BiasHost); + for (int i = 0; i < param.numExperts; i++) + { + biasPtr[i] = 0.01f * (i % 100); + } + this->mBufferManager->copy(*float32BiasHost, *float32BiasDevice); + + // Setup normal buffers (scores, etc.) + float* scoresHostPtr = bufferCast(*this->mPtrScoresHost); + initData(scoresHostPtr, param.numTokens * param.numExperts, 42); + this->mBufferManager->copy(*this->mPtrScoresHost, *this->mPtrScoresDevice); + this->mStream->synchronize(); + + // Setup routing data with float32 bias + moe::dev::routing::routingDeepSeek::Data routingData; + this->setCommonParams(param, routingData); + routingData.mDtypeOutput = btg::Dtype::Bfloat16; + routingData.mPtrScores = bufferCast(*this->mPtrScoresDevice); + routingData.mPtrRoutingBias = bufferCast(*float32BiasDevice); + routingData.mDtypeBias = btg::Dtype::Fp32; // Float32 bias with bf16 output + routingData.mNumExpertGroups = param.nGroup; + routingData.mNumLimitedGroups = param.topkGroup; + routingData.mRouteScale = param.routedScalingFactor; + routingData.mUseRoutingSoftmax = false; + + // Run kernel — verifies it doesn't crash with float32 bias + moe::dev::routing::routingDeepSeek::run(routingData, this->mStream->get()); + this->mStream->synchronize(); +}; + TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization384) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/384, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization512) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/512, /*topK=*/22, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(512) + .withTopK(22) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300, - /*numExperts=*/256, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(20300) + .withNumExperts(256) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization384) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300, - /*numExperts=*/384, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); - this->runTest(param); -}; - -TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization512) -{ - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/20300, - /*numExperts=*/512, /*topK=*/22, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 1, /*topkGroup*/ 1, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(20300) + .withNumExperts(384) + .withTopK(8) + .withTileTokensDim(256) + .withNGroup(1) + .withTopkGroup(1) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationTop2) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10, - /*numExperts=*/256, /*topK=*/2, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(10) + .withNumExperts(256) + .withTopK(2) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop2) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/100, - /*numExperts=*/256, /*topK=*/2, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(100) + .withNumExperts(256) + .withTopK(2) + .withExpertParallelization(2, 1) + .withTileTokensDim(192) + .withNGroup(8) + .withTopkGroup(4) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop2) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/256, /*topK=*/2, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(256) + .withTopK(2) + .withTileTokensDim(256) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelizationTop8) { - RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/32, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::DeepSeekV3) + .withNumTokens(1030) + .withNumExperts(32) + .withTopK(8) + .withTileTokensDim(256) + .withUseTopKAsInput(true) + .withHasInvalidTopKInput(true) + .withNGroup(8) + .withTopkGroup(4) + .withRequiredComputeCapability(10) + .build(); this->runTest(param); }; } // namespace diff --git a/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp b/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp index 6c4b5032c66..f889a7b79db 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp @@ -63,28 +63,25 @@ class RoutingLlama4KernelTest : public RoutingKernelTest (a.score > b.score) || (a.score == b.score && a.idx < b.idx)); //@TODO: check if this is correct }); - // Apply sigmoid to the top-k scores + // Apply sigmoid to top-K scores, then store results. + // mPtrTopKPacked stores SIGMOID scores (matching what the scores-path kernels produce). + // The cluster/device kernels pass these through as-is to mPtrTopKWeights. for (int ie = 0; ie < param.topK; ++ie) { auto finalScore = 1.F / (1.F + std::exp(-expIdx[ie].score)); - expIdx[ie].score = static_cast(finalScore); - } - // convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked - for (int ie = 0; ie < param.topK; ++ie) - { - PackedType si{static_cast(expIdx[ie].score), expIdx[ie].idx}; + PackedType si{static_cast(finalScore), expIdx[ie].idx}; reinterpret_cast(bufferCast(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si; if (param.useTopKAsInput) { bufferCast(*this->mPtrTopKIdsHost)[it * param.topK + ie] = static_cast(expIdx[ie].idx); - bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); + bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(finalScore); } else if (param.getExpWeights) { - bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); + bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(finalScore); } } } @@ -102,7 +99,7 @@ class RoutingLlama4KernelTest : public RoutingKernelTest void setParams(RoutingKernelTestParam const& param, RoutingData& routingData) { RoutingKernelTest::setCommonParams(param, routingData); - routingData.mDtypeExpW = btg::Dtype::Bfloat16; + routingData.mDtypeOutput = btg::Dtype::Bfloat16; routingData.mPtrTopKPacked = reinterpret_cast(bufferCast(*this->mPtrTopKPackedDevice)); if (param.useTopKAsInput) @@ -110,6 +107,12 @@ class RoutingLlama4KernelTest : public RoutingKernelTest routingData.mPtrTopKIds = bufferCast(*this->mPtrTopKIdsDevice); routingData.mPtrScores = nullptr; } + else if (param.useTopKPackedAsInput) + { + // mPtrTopKPacked is already set above; just clear scores and topKIds + routingData.mPtrTopKIds = nullptr; + routingData.mPtrScores = nullptr; + } else { routingData.mPtrTopKIds = nullptr; @@ -130,69 +133,128 @@ TYPED_TEST_SUITE(RoutingLlama4KernelTest, Bf16Types); TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/3, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f, - /*requiredComputeCapability*/ 8); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(3) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/300, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(300) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelizationTopKAsInput) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/3, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f, - /*requiredComputeCapability*/ 8); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(3) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelizationTopKAsInput) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .build(); this->runTest(param); }; TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelizationTopKAsInput) { - RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/300, - /*numExperts=*/128, /*topK=*/1, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(300) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +// --- Tests for useTopKPackedAsInput (mPtrTopKPacked without mPtrScores) --- +// For Llama4, the kernels apply sigmoid_accurate to packed scores, +// so the packed input path goes through Llama4-specific kernels (not runPostTopKPipeline). + +TYPED_TEST(RoutingLlama4KernelTest, WarpLevelTopKPackedAsInput) +{ + // Small token count -> warp-level kernel + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(3) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKPackedAsInput(true) + .withRequiredComputeCapability(8) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelTopKPackedAsInput) +{ + // Medium token count -> cluster-level kernel + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(10) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKPackedAsInput(true) + .build(); + this->runTest(param); +}; + +TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelTopKPackedAsInput) +{ + // Large token count -> multi-kernel pipeline + auto param = RoutingKernelTestParam() + .withRoutingMethod(RoutingMethodType::Llama4) + .withNumTokens(300) + .withNumExperts(128) + .withTopK(1) + .withTileTokensDim(8) + .withUseTopKPackedAsInput(true) + .withRequiredComputeCapability(8) + .build(); this->runTest(param); }; diff --git a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp deleted file mode 100644 index a6fce8ce49c..00000000000 --- a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp +++ /dev/null @@ -1,453 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tests/unit_tests/kernels/routing/routingTest.h" - -namespace tk = tensorrt_llm::kernels; -namespace btg = batchedGemm::trtllm::gen; -using namespace tensorrt_llm::runtime; -using namespace tensorrt_llm::tests::kernels::routing; - -namespace -{ - -template -class RoutingRenormalizeKernelTest : public RoutingKernelTest -{ - -protected: - using RoutingKernelTest::mSeed; - using RoutingKernelTest::mStream; - using RoutingKernelTest::mBufferManager; - using typename RoutingKernelTest::PackedType; - -private: - // private methods - void computeTopKExperts(RoutingKernelTestParam const& param) override - { - for (int it = 0; it < param.numTokens; ++it) - { - PackedFloat expWeightsIdx[param.numExperts]; - PackedFloat expIdx[param.topK]; - float sum = float{0.0f}; - float maxScore = -std::numeric_limits::infinity(); - for (int ie = 0; ie < param.numExperts; ++ie) - { - float score; - int16_t newIdx = static_cast(ie); - score = static_cast(bufferCast(*this->mPtrScoresHost)[it * param.numExperts + ie]); - - if (param.doSoftmaxBeforeTopK && score > maxScore) - { - maxScore = score; - } - - PackedFloat si{static_cast(score), newIdx}; - expWeightsIdx[ie] = si; - } - - if (param.doSoftmaxBeforeTopK) - { - // Run softmax before topk - for (int ie = 0; ie < param.numExperts; ++ie) - { - expWeightsIdx[ie].score - = static_cast(std::exp(static_cast(expWeightsIdx[ie].score) - maxScore)); - sum += expWeightsIdx[ie].score; - } - - for (int ie = 0; ie < param.numExperts; ++ie) - { - float score = static_cast(expWeightsIdx[ie].score); - score /= sum; - expWeightsIdx[ie].score = static_cast(score); - } - } - - // Calculate the top-k scores and indices - std::partial_sort_copy(expWeightsIdx, expWeightsIdx + param.numExperts, expIdx, expIdx + param.topK, comp); - - if (param.doSoftmaxBeforeTopK) - { - // Normalize the value after the topk - if (param.normTopkProb) - { - float sum = float{0.0f}; - for (int ie = 0; ie < param.topK; ++ie) - { - sum += static_cast(expIdx[ie].score); - } - for (int ie = 0; ie < param.topK; ++ie) - { - float score = static_cast(expIdx[ie].score); - score /= sum; - expIdx[ie].score = static_cast(score); - } - } - } - else - { - // Perform softmax after topk - float sum = float{0.0f}; - float maxScore = -std::numeric_limits::infinity(); - float score; - for (int ie = 0; ie < param.topK; ++ie) - { - score = static_cast(expIdx[ie].score); - maxScore = score >= maxScore ? score : maxScore; - } - for (int ie = 0; ie < param.topK; ++ie) - { - score = static_cast(expIdx[ie].score) - maxScore; - score = std::exp(score); - sum += score; - } - for (int ie = 0; ie < param.topK; ++ie) - { - score = static_cast(expIdx[ie].score) - maxScore; - score = static_cast(std::exp(score)); - score /= sum; - expIdx[ie].score = static_cast(score); - } - } - - // convert back to io_dtype and store the topk expert results in hostData.mPtrTopKPacked - for (int ie = 0; ie < param.topK; ++ie) - { - // Set invalid topk indices for the first half of the topk - if (param.hasInvalidTopKInput && ie < param.topK / 2 + 1) - { - expIdx[ie].idx = -1; - } - - PackedType si{static_cast(expIdx[ie].score), expIdx[ie].idx}; - reinterpret_cast(bufferCast(*this->mPtrTopKPackedHost))[it * param.topK + ie] = si; - if (param.useTopKAsInput) - { - bufferCast(*this->mPtrTopKIdsHost)[it * param.topK + ie] - = static_cast(expIdx[ie].idx); - bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); - } - else if (param.getExpWeights) - { - bufferCast(*this->mPtrTopKWeightsHost)[it * param.topK + ie] = static_cast(expIdx[ie].score); - } - } - } - } - - void allocateBuffers(RoutingKernelTestParam const& param) override - { - RoutingKernelTest::allocateBuffers(param); - int64_t scoresSize = param.numTokens * param.numExperts; - this->mPtrScoresHost = mBufferManager->pinned(ITensor::makeShape({scoresSize}), TRTDataType::value); - this->mPtrScoresDevice = mBufferManager->gpu(ITensor::makeShape({scoresSize}), TRTDataType::value); - } - - template - void setParams(RoutingKernelTestParam const& param, RoutingData& routingData) - { - RoutingKernelTest::setCommonParams(param, routingData); - - if (sizeof(T) == 4) - { - routingData.mDtypeExpW = btg::Dtype::Fp32; - } - else - { - routingData.mDtypeExpW = btg::Dtype::Bfloat16; - } - - // Special case for RenormalizeNaive - routingData.mDoSoftmaxBeforeTopK = param.routingMethod == RoutingMethodType::RenormalizeNaive; - routingData.mNormTopkProb = param.routingMethod == RoutingMethodType::RenormalizeNaive; - - if (param.useTopKAsInput) - { - routingData.mPtrTopKIds = bufferCast(*this->mPtrTopKIdsDevice); - routingData.mPtrScores = nullptr; - } - else - { - routingData.mPtrTopKIds = nullptr; - routingData.mPtrScores = bufferCast(*this->mPtrScoresDevice); - } - } - - void callTestedFunction( - RoutingKernelTestParam const& param, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override - { - moe::dev::routing::routingRenormalize::Data routingData; - setParams(param, routingData); - moe::dev::routing::routingRenormalize::run(routingData, mStream->get()); - } -}; - -TYPED_TEST_SUITE(RoutingRenormalizeKernelTest, FloatAndBf16Types); - -TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithExpertParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationWithInvalidTopKInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormalizeNaive) -{ - RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, - /*numExperts=*/512, /*topK=*/10, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/200, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithInvalidTopKInputTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/200, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertParallelizationTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/256, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormalizeNaiveTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/10, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/8, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationTop4) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, - /*numExperts=*/128, /*topK=*/4, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockBasic) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/8, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockMaxTokens) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/16, - /*numExperts=*/512, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithExpertParallelization) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/12, - /*numExperts=*/512, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*tileTokensDim=*/192, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithTopKAsInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/8, - /*numExperts=*/128, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithInvalidTopKInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10, - /*numExperts=*/512, /*topK=*/8, - /*expertParallelization=*/2, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DynBlockWithRenormalizeNaive) -{ - RoutingKernelTestParam param(RoutingMethodType::RenormalizeNaive, /*numTokens=*/16, - /*numExperts=*/512, /*topK=*/8, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, BlockLevelParallelizationLargeN) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/4, - /*numExperts=*/2048, /*topK=*/32, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationLargeN) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/100, - /*numExperts=*/2048, /*topK=*/32, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeN) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, - /*numExperts=*/2048, /*topK=*/32, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/false, /*hasInvalidTopKInput=*/false, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); - this->runTest(param); -}; - -TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelizationLargeNWithInvalidTopKInput) -{ - RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/1000, - /*numExperts=*/2048, /*topK=*/32, - /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*tileTokensDim=*/256, - /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*useTopKAsInput=*/true, /*hasInvalidTopKInput=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); - this->runTest(param); -}; -} // end namespace diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingTest.cpp index a7da5a6d728..3b5747f2893 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.cpp @@ -297,6 +297,24 @@ void RoutingKernelTest::verifyExpertRoutingIndices(RoutingKernelTestParam con EXPECT_EQ(checkSetEqual(ie, permutedIdx, permutedIdxTest, "permuted idx"), true); EXPECT_EQ(checkSetEqual(ie, tokenIdx, tokenIdxTest, "token idx"), true); } + + // Verify that invalid expert entries produce expandedIdxToPermutedIdx == -1. + // The loop above only checks valid experts (0..numExperts-1) and skips invalid entries. + if (param.hasInvalidTopKInput) + { + for (int it = 0; it < param.numTokens * param.topK; ++it) + { + int16_t const expertIdx = expIdxHostPtr[it].idx; + bool const isInvalid = (expertIdx < 0) || (expertIdx >= param.numExperts); + if (isInvalid) + { + int32_t const permIdxTest = hostExpToPermTest[it]; + EXPECT_EQ(permIdxTest, -1) + << "expandedIdxToPermutedIdx[" << it << "] should be -1 for invalid expertId=" << expertIdx + << " but got " << permIdxTest; + } + } + } } template @@ -326,9 +344,18 @@ void RoutingKernelTest::verifyResult(RoutingKernelTestParam const& param) } // expert counts aren't always used, but if tokens > 8 * 1024, we are sure they are used if (param.numTokens > param.singleClusterTokenNum) - { //@Todo: check if this is always true + { assertEqual(bufferCast(*mPtrExpertCountsHost), expertCountsPtr, param.numExperts, "expert counts"); - if (param.routingMethod != RoutingMethodType::DeepSeekV3) + // The second half of mPtrExpertCounts is only filled by the multi-kernel offsets pipeline + // (routingIndicesOffsetsKernel). It is NOT filled by the coop kernel or cluster kernel. + // On SM90+, both the scores path (RoutingCustom.cu) and the post-topK path + // (RoutingFromTopKIds.cu) may use the coop kernel instead of multi-kernel for medium + // token counts. Skip this check whenever the coop path could have been taken. + // The coop path requires SM90+ and numExperts <= 1024. + bool const coopMayBeUsed = (mDeviceProp.major >= 9) && (param.numExperts <= 1024); + bool const useMultiKernelPath = !param.useTopKAsInput && !param.useTopKPackedAsInput + && param.routingMethod != RoutingMethodType::DeepSeekV3 && !coopMayBeUsed; + if (useMultiKernelPath) { assertEqual(bufferCast(*mPtrExpertCountsHost), expertCountsPtr + param.numExperts, param.numExperts, "expert counts (2)"); @@ -370,6 +397,13 @@ void RoutingKernelTest::runTest(RoutingKernelTestParam const& param) mBufferManager->copy(*mPtrTopKWeightsHost, *mPtrTopKWeightsDevice); mStream->synchronize(); } + else if (param.useTopKPackedAsInput) + { + // Set the topk_packed as input (computed by host reference, no scores) + mBufferManager->copy(*mPtrTopKPackedHost, *mPtrTopKPackedDevice); + mBufferManager->copy(*mPtrTopKWeightsHost, *mPtrTopKWeightsDevice); + mStream->synchronize(); + } // Retrieve the workspace size of the routing kernel. auto const workspaceSize = getDeviceWorkspaceSize(param); TensorPtr workspaceDevice diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.h b/cpp/tests/unit_tests/kernels/routing/routingTest.h index 4a50cbbc795..aed28023292 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.h +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.h @@ -36,6 +36,8 @@ typedef testing::Types FloatAndBf16Types; typedef testing::Types<__nv_bfloat16> Bf16Types; using RoutingMethodType = tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::Routing::RoutingMethodType; +using RoutingPreprocessType = moe::dev::routing::RoutingPreprocessType; +using RoutingPostprocessType = moe::dev::routing::RoutingPostprocessType; using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr; using namespace tensorrt_llm::runtime; @@ -226,14 +228,13 @@ inline auto comp = [](PackedFloat const& a, PackedFloat const& b) struct RoutingKernelTestParam { RoutingMethodType routingMethod{RoutingMethodType::Renormalize}; - int32_t numTokens; - int32_t numExperts; + int32_t numTokens{0}; + int32_t numExperts{0}; uint32_t topK{1}; int32_t localExpertsStartIdx{0}; int32_t localExpertsStrideLog2{0}; - // we don't use any special striding, and we always test the GPU at logical idx 0 - int32_t numLocalExperts{128}; + int32_t numLocalExperts{0}; int32_t paddingLog2{3}; int32_t tileTokensDim{1}; @@ -244,13 +245,19 @@ struct RoutingKernelTestParam int requiredComputeCapability{9}; // Check the input parameters - bool useTopKAsInput{false}; + bool useTopKAsInput{false}; // When true, mPtrTopKIds + mPtrTopKWeights are provided as input + bool useTopKPackedAsInput{false}; // When true, mPtrTopKPacked is provided as input (without mPtrScores) bool hasInvalidTopKInput{false}; + int32_t invalidExpertIdValue{-1}; // Value used to mark invalid topK entries: -1 or numExperts // Special for renormalize routing method bool doSoftmaxBeforeTopK{false}; bool normTopkProb{true}; + // Policy type selection for routingCustom (set automatically by build() if not overridden) + RoutingPreprocessType preprocessType{RoutingPreprocessType::None}; + RoutingPostprocessType postprocessType{RoutingPostprocessType::Softmax}; + // Special for deepseek routing method int32_t nGroup{0}; int32_t topkGroup{0}; @@ -259,59 +266,239 @@ struct RoutingKernelTestParam // Default constructor RoutingKernelTestParam() = default; - // Constructor with required parameters - RoutingKernelTestParam(int32_t nt, int32_t ne, uint32_t tk = 1) - : numTokens(nt) - , numExperts(ne) - , topK(tk) - { - } - - // Constructor with all parameters - RoutingKernelTestParam(RoutingMethodType routingMethod, int32_t numTokens, int32_t numExperts, uint32_t topK, - int32_t expertParallelization = 1, int32_t expertParallelizationId = 0, int32_t tileTokensDim = 1, - int32_t paddingLog2 = 3, int32_t localExpertsStrideLog2 = 0, bool usePdl = true, bool getExpWeights = true, - bool useTopKAsInput = false, bool hasInvalidTopKInput = false, int32_t nGroup = 1, int32_t topkGroup = 1, - float routedScalingFactor = 1.0f, int requiredComputeCapability = 9) - : routingMethod(routingMethod) - , numTokens(numTokens) - , numExperts(numExperts) - , topK(topK) - , tileTokensDim(tileTokensDim) - , paddingLog2(paddingLog2) - , localExpertsStrideLog2(localExpertsStrideLog2) - , usePdl(usePdl) - , getExpWeights(getExpWeights) - , useTopKAsInput(useTopKAsInput) - , hasInvalidTopKInput(hasInvalidTopKInput) - , nGroup(nGroup) - , topkGroup(topkGroup) - , routedScalingFactor(routedScalingFactor) - , requiredComputeCapability(requiredComputeCapability) - { - // Check the routing method - if (routingMethod != RoutingMethodType::Renormalize && routingMethod != RoutingMethodType::RenormalizeNaive - && routingMethod != RoutingMethodType::Llama4 && routingMethod != RoutingMethodType::DeepSeekV3) + // Copy / move constructors and assignment operators + RoutingKernelTestParam(RoutingKernelTestParam const& other) = default; + RoutingKernelTestParam(RoutingKernelTestParam&& other) = default; + RoutingKernelTestParam& operator=(RoutingKernelTestParam const& other) = default; + RoutingKernelTestParam& operator=(RoutingKernelTestParam&& other) = default; + ~RoutingKernelTestParam() = default; + + // + // Fluent builder methods — each returns *this so calls can be chained. + // Usage: + // auto param = RoutingKernelTestParam() + // .withRoutingMethod(RoutingMethodType::Renormalize) + // .withNumTokens(4) + // .withNumExperts(128) + // .withTopK(8) + // .build(); + // + + RoutingKernelTestParam& withRoutingMethod(RoutingMethodType val) + { + routingMethod = val; + return *this; + } + + RoutingKernelTestParam& withNumTokens(int32_t val) + { + numTokens = val; + return *this; + } + + RoutingKernelTestParam& withNumExperts(int32_t val) + { + numExperts = val; + return *this; + } + + RoutingKernelTestParam& withTopK(uint32_t val) + { + topK = val; + return *this; + } + + RoutingKernelTestParam& withExpertParallelization(int32_t ep, int32_t epId = 0) + { + mExpertParallelization = ep; + mExpertParallelizationId = epId; + return *this; + } + + RoutingKernelTestParam& withTileTokensDim(int32_t val) + { + tileTokensDim = val; + return *this; + } + + RoutingKernelTestParam& withPaddingLog2(int32_t val) + { + paddingLog2 = val; + return *this; + } + + RoutingKernelTestParam& withLocalExpertsStrideLog2(int32_t val) + { + localExpertsStrideLog2 = val; + return *this; + } + + RoutingKernelTestParam& withUsePdl(bool val) + { + usePdl = val; + return *this; + } + + RoutingKernelTestParam& withGetExpWeights(bool val) + { + getExpWeights = val; + return *this; + } + + RoutingKernelTestParam& withUseTopKAsInput(bool val) + { + useTopKAsInput = val; + return *this; + } + + RoutingKernelTestParam& withUseTopKPackedAsInput(bool val) + { + useTopKPackedAsInput = val; + return *this; + } + + RoutingKernelTestParam& withHasInvalidTopKInput(bool val) + { + hasInvalidTopKInput = val; + return *this; + } + + RoutingKernelTestParam& withInvalidExpertIdValue(int32_t val) + { + invalidExpertIdValue = val; + return *this; + } + + RoutingKernelTestParam& withNGroup(int32_t val) + { + nGroup = val; + return *this; + } + + RoutingKernelTestParam& withTopkGroup(int32_t val) + { + topkGroup = val; + return *this; + } + + RoutingKernelTestParam& withRoutedScalingFactor(float val) + { + routedScalingFactor = val; + return *this; + } + + RoutingKernelTestParam& withPreprocessType(RoutingPreprocessType val) + { + preprocessType = val; + mPreprocessTypeOverridden = true; + return *this; + } + + RoutingKernelTestParam& withPostprocessType(RoutingPostprocessType val) + { + postprocessType = val; + mPostprocessTypeOverridden = true; + return *this; + } + + RoutingKernelTestParam& withNormTopkProb(bool val) + { + normTopkProb = val; + mNormTopkProbOverridden = true; + return *this; + } + + RoutingKernelTestParam& withRequiredComputeCapability(int val) + { + requiredComputeCapability = val; + return *this; + } + + /// Finalize and validate. Must be called after all `with*()` setters. + RoutingKernelTestParam& build() + { + // Validate routing method + if (routingMethod != RoutingMethodType::Default && routingMethod != RoutingMethodType::Renormalize + && routingMethod != RoutingMethodType::RenormalizeNaive && routingMethod != RoutingMethodType::Llama4 + && routingMethod != RoutingMethodType::DeepSeekV3 && routingMethod != RoutingMethodType::MiniMax2 + && routingMethod != RoutingMethodType::SigmoidRenorm) { throw std::invalid_argument("Invalid routing method"); } - // Set about the expert parallelization - numLocalExperts = numExperts / expertParallelization; - localExpertsStartIdx = numLocalExperts * expertParallelizationId; + // Derive expert parallelization parameters + numLocalExperts = numExperts / mExpertParallelization; + localExpertsStartIdx = numLocalExperts * mExpertParallelizationId; - // Apply routing method specific settings - if (routingMethod == RoutingMethodType::RenormalizeNaive) + // Apply routing-method-specific settings + if (routingMethod == RoutingMethodType::Default) + { + doSoftmaxBeforeTopK = true; + normTopkProb = false; + } + else if (routingMethod == RoutingMethodType::RenormalizeNaive) { doSoftmaxBeforeTopK = true; normTopkProb = true; } + else if (routingMethod == RoutingMethodType::SigmoidRenorm) + { + doSoftmaxBeforeTopK = false; + if (!mNormTopkProbOverridden) + { + normTopkProb = true; + } + } else { doSoftmaxBeforeTopK = false; normTopkProb = false; } + // Derive policy types from routing method when not explicitly set + if (!mPreprocessTypeOverridden) + { + if (routingMethod == RoutingMethodType::Default || routingMethod == RoutingMethodType::RenormalizeNaive) + { + preprocessType = RoutingPreprocessType::Softmax; + } + else if (routingMethod == RoutingMethodType::MiniMax2) + { + preprocessType = RoutingPreprocessType::SigmoidBias; + } + else if (routingMethod == RoutingMethodType::SigmoidRenorm) + { + preprocessType = RoutingPreprocessType::Sigmoid; + } + else + { + preprocessType = RoutingPreprocessType::None; + } + } + if (!mPostprocessTypeOverridden) + { + if (routingMethod == RoutingMethodType::Default) + { + postprocessType = RoutingPostprocessType::None; + } + else if (routingMethod == RoutingMethodType::RenormalizeNaive) + { + postprocessType = RoutingPostprocessType::SumNormalize; + } + else if (routingMethod == RoutingMethodType::MiniMax2) + { + postprocessType = RoutingPostprocessType::ScaledSumNormalize; + } + else if (routingMethod == RoutingMethodType::SigmoidRenorm) + { + postprocessType = RoutingPostprocessType::SumNormalize; + } + else + { + postprocessType = RoutingPostprocessType::Softmax; + } + } + // Set singleClusterTokenNum if (routingMethod == RoutingMethodType::DeepSeekV3) { @@ -322,36 +509,36 @@ struct RoutingKernelTestParam singleClusterTokenNum = 256; } + // Cross-field validation if (hasInvalidTopKInput && !useTopKAsInput) { throw std::invalid_argument("hasInvalidTopKInput is only supported when useTopKAsInput is true"); } - } - - // Copy constructor - RoutingKernelTestParam(RoutingKernelTestParam const& other) = default; - - // Move constructor - RoutingKernelTestParam(RoutingKernelTestParam&& other) = default; - - // Copy assignment operator - RoutingKernelTestParam& operator=(RoutingKernelTestParam const& other) = default; - - // Move assignment operator - RoutingKernelTestParam& operator=(RoutingKernelTestParam&& other) = default; + if (useTopKAsInput && useTopKPackedAsInput) + { + throw std::invalid_argument("useTopKAsInput and useTopKPackedAsInput are mutually exclusive"); + } - // Destructor - ~RoutingKernelTestParam() = default; + return *this; + } std::string toString() const { return tensorrt_llm::common::fmtstr( "RoutingKernelTestParam[num_tokens=%d, num_experts=%d, topK=%u, doSoftmaxBeforeTopK=%d, normTopkProb=%d, " "localExpertsStartIdx=%d, localExpertsStrideLog2=%d, numLocalExperts=%d, usePdl=%d, useTopKAsInput=%d, " - "hasInvalidTopKInput=%d]", + "useTopKPackedAsInput=%d, hasInvalidTopKInput=%d]", numTokens, numExperts, topK, doSoftmaxBeforeTopK, normTopkProb, localExpertsStartIdx, - localExpertsStrideLog2, numLocalExperts, usePdl, useTopKAsInput, hasInvalidTopKInput); + localExpertsStrideLog2, numLocalExperts, usePdl, useTopKAsInput, useTopKPackedAsInput, hasInvalidTopKInput); } + +private: + // Builder state — used by build() to derive public fields. + int32_t mExpertParallelization{1}; + int32_t mExpertParallelizationId{0}; + bool mPreprocessTypeOverridden{false}; + bool mPostprocessTypeOverridden{false}; + bool mNormTopkProbOverridden{false}; }; template diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index 0addc29a566..66b83df12a4 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -11,6 +11,7 @@ get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2, next_positive_power_of_2) +from tensorrt_llm._utils import get_sm_version from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, OptimizationProfile, TunableRunner, TuningConfig) @@ -89,6 +90,18 @@ def prepare_dummy_topk_and_hook( lambda: torch.randn( num_experts, dtype=torch.bfloat16, device=hidden_states.device) }) + if routing_method_type == RoutingMethodType.MiniMax2: + routing_cls_kwargs.update({ + 'callable_e_score_correction_bias': + lambda: torch.randn( + num_experts, dtype=torch.bfloat16, device=hidden_states.device), + 'num_experts': + num_experts, + }) + if routing_method_type == RoutingMethodType.SigmoidRenorm: + routing_cls_kwargs.update({ + 'num_experts': num_experts, + }) routing_method = ROUTING_METHOD_TYPE_TO_CLASS[routing_method_type]( top_k=top_k, **routing_cls_kwargs) @@ -684,10 +697,27 @@ def _constrain_to_num_tokens(shapes: Tuple[torch.Size]) -> int: return num_tokens HS_SCALE_IDX = 3 - CONSTRAINED_HS_SCALE_DIM = 1 - constraint_hidden_states_scale = ConstraintSpec( - HS_SCALE_IDX, CONSTRAINED_HS_SCALE_DIM, _constrain_to_num_tokens) + if get_sm_version() >= 100: + # SM100+: fp8_quantize_1x128 returns 2D scale (blocked_n, num_tokens) + CONSTRAINED_HS_SCALE_DIM = 1 + constraint_hidden_states_scale = ConstraintSpec( + HS_SCALE_IDX, CONSTRAINED_HS_SCALE_DIM, + _constrain_to_num_tokens) + else: + # SM90: fp8_quantize_1x128 returns 1D scale with layout matching + # the fp8_quantize_1x128 custom op shape formula. + def _constrain_hs_scale_sm90(shapes: Tuple[torch.Size]) -> int: + num_tokens = shapes[2][0] + hidden_size = shapes[2][1] + pad_m = fp4_utils.pad_up(num_tokens, 4) + blocked_n = (hidden_size + 127) // 128 + return fp4_utils.pad_up(pad_m * blocked_n * 4, 128) // 4 + + CONSTRAINED_HS_SCALE_DIM = 0 + constraint_hidden_states_scale = ConstraintSpec( + HS_SCALE_IDX, CONSTRAINED_HS_SCALE_DIM, + _constrain_hs_scale_sm90) ROUTER_LOGITS_IDX = 0 CONSTRAINED_RL_DIM = 0 diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 23354f5a5b3..839b741f8f7 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -44,7 +44,8 @@ W4A16MXFP4TRTLLMGenFusedMoEMethod) # isort: on from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod, - DefaultMoeRoutingMethod) + DefaultMoeRoutingMethod, MiniMaxM2MoeRoutingMethod, + SigmoidRenormMoeRoutingMethod) class TRTLLMGenFusedMoE(MoE): @@ -572,6 +573,18 @@ def run_moe( n_group = self.routing_method.routing_impl.n_group topk_group = self.routing_method.routing_impl.topk_group routed_scaling_factor = self.routing_method.routing_impl.routed_scaling_factor + elif isinstance(self.routing_method, MiniMaxM2MoeRoutingMethod): + top_k = self.routing_method.top_k + routing_bias = self.routing_method.e_score_correction_bias + n_group = None + topk_group = None + routed_scaling_factor = None + elif isinstance(self.routing_method, SigmoidRenormMoeRoutingMethod): + top_k = self.routing_method.top_k + routing_bias = None + n_group = None + topk_group = None + routed_scaling_factor = None else: top_k = self.routing_method.top_k routing_bias = None @@ -594,7 +607,7 @@ def run_moe( if self.has_deepseek_fp8_block_scales: assert do_finalize, "fp8_block_scale_moe_runner does not support do_finalize=False" - # fp8_block_scale_moe_runner needs 2D shape for x_sf and only support SM100+ + # fp8_quantize_1x128 returns 2D x_sf on SM100+, 1D on SM90 if x_sf is None: x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x) result = self.op_backend.run_fp8_block_scale_moe( @@ -1069,8 +1082,11 @@ def forward_fake( else: is_deepseek_v3_routing = isinstance(self.routing_method, DeepSeekV3MoeRoutingMethod) + is_minimax_routing = isinstance(self.routing_method, + MiniMaxM2MoeRoutingMethod) top_k = self.routing_method.routing_impl.top_k if is_deepseek_v3_routing else self.routing_method.top_k - routing_bias = self.routing_method.e_score_correction_bias if is_deepseek_v3_routing else None + routing_bias = self.routing_method.e_score_correction_bias if ( + is_deepseek_v3_routing or is_minimax_routing) else None return fp4_block_scale_fake_output_without_finalize( x, self.num_experts, diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index 69498c96cfc..e8f876e85eb 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -157,8 +157,10 @@ class RoutingMethodType(IntEnum): RenormalizeNaive = 4, # MiniMaxM2: Sigmoid -> RoutingBiasAdd -> TopK -> Renormalize(without bias) MiniMax2 = 5, + # SigmoidRenorm: Sigmoid -> TopK -> Renormalize + SigmoidRenorm = 6, # Unspecified - Unspecified = 6, + Unspecified = 7, class BaseMoeRoutingMethod(nn.Module): @@ -437,6 +439,38 @@ def routing_method_type(self): return RoutingMethodType.MiniMax2 +class SigmoidRenormMoeRoutingMethod(BaseMoeRoutingMethod): + + def __init__( + self, + top_k: int, + num_experts: int, + renormalize: bool = True, + output_dtype: torch.dtype = torch.float32, + ): + super().__init__() + self.top_k = top_k + self.num_experts = num_experts + self.renormalize = renormalize + self.output_dtype = output_dtype + + def apply(self, + router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + scores = torch.sigmoid(router_logits) + topk_weights, topk_idx = torch.topk(scores, + k=self.top_k, + dim=-1, + sorted=False) + if self.renormalize: + topk_weights = topk_weights / ( + topk_weights.sum(dim=-1, keepdim=True) + 1e-20) + return topk_idx.to(torch.int32), topk_weights.to(self.output_dtype) + + @property + def routing_method_type(self): + return RoutingMethodType.SigmoidRenorm + + class RenormalizeMoeRoutingMethod(BaseMoeRoutingMethod): def __init__( @@ -647,6 +681,8 @@ def routing_method_type(self) -> RoutingMethodType: BaseMoeRoutingMethod, RoutingMethodType.MiniMax2: MiniMaxM2MoeRoutingMethod, + RoutingMethodType.SigmoidRenorm: + SigmoidRenormMoeRoutingMethod, } diff --git a/tests/unittest/_torch/modules/moe/moe_test_utils.py b/tests/unittest/_torch/modules/moe/moe_test_utils.py index 604b42b56fa..7d1c7668fc9 100644 --- a/tests/unittest/_torch/modules/moe/moe_test_utils.py +++ b/tests/unittest/_torch/modules/moe/moe_test_utils.py @@ -164,25 +164,22 @@ def should_skip_trtllm( return None # Routing method compatibility check (used by test_moe_module.py) - # TRTLLMGen C++ routing kernel (runner.cu) only implements: - # - DeepSeekV3 (requires float32 routing_logits) + # TRTLLMGen C++ routing kernel (runner.cu) implements: + # - DeepSeekV3 (nGroup<=1: SigmoidBias+ScaledSumNormalize; nGroup>1: full DeepSeek kernel) + # - SigmoidRenorm (sigmoid activation, sum-normalize) + # - MiniMax2 (sigmoid activation, bias-added selection, scaled sum-normalize) # - Llama4 (requires top_k=1) - # - Renormalize - # - RenormalizeNaive - # See: cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu:77-212 + # - Renormalize / RenormalizeNaive / Default (softmax-based) + # See: cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/runner.cu if routing_method_cls is not None: from tensorrt_llm._torch.modules.fused_moe import ( DeepSeekV3MoeRoutingMethod, - DefaultMoeRoutingMethod, Llama4RenormalizeMoeRoutingMethod, - MiniMaxM2MoeRoutingMethod, ) # Routing methods NOT implemented in C++ kernel - trtllm_unimplemented_routing = ( - DefaultMoeRoutingMethod, # runner.cu:210 - "Unimplemented routing method" - MiniMaxM2MoeRoutingMethod, # runner.cu:210 - "Unimplemented routing method" - ) + # (Currently all routing methods are supported by runner.cu) + trtllm_unimplemented_routing = () if routing_method_cls in trtllm_unimplemented_routing: routing_name = routing_method_cls.__name__ return ( diff --git a/tests/unittest/_torch/modules/moe/test_moe_module.py b/tests/unittest/_torch/modules/moe/test_moe_module.py index 3c77e495775..bb221cb1204 100644 --- a/tests/unittest/_torch/modules/moe/test_moe_module.py +++ b/tests/unittest/_torch/modules/moe/test_moe_module.py @@ -257,9 +257,7 @@ def _run_autotune_test( _ = run_forward_fn() # Check if we should run full tactic replay - if not run_all_tactics or not supports_autotuner_capture( - backend_type, quant_algo, use_flashinfer - ): + if not run_all_tactics or not supports_autotuner_capture(backend_type, quant_algo): # Simple accuracy check for unsupported backends or when run_all_tactics is False with torch.inference_mode(): output = run_forward_fn() @@ -1151,7 +1149,7 @@ def test_configurable_moe_single_gpu( comm_methods=COMM_METHODS, swiglu_combos=SWIGLU_COMBOS, model_configs=MOE_MODEL_CONFIGS, - seq_lens=[8] if IS_CI_MODE else SEQ_LENS, + seq_lens=[1, 8] if IS_CI_MODE else SEQ_LENS, dtypes=DTYPES, backend_types=BACKEND_TYPES, quant_algos=QUANT_ALGOS, diff --git a/tests/unittest/_torch/thop/serial/test_moe.py b/tests/unittest/_torch/thop/serial/test_moe.py index 53c70ee21c0..ca0f56eb585 100644 --- a/tests/unittest/_torch/thop/serial/test_moe.py +++ b/tests/unittest/_torch/thop/serial/test_moe.py @@ -307,6 +307,65 @@ def routing_reference_renormalize_naive(expert_logits, top_k, padding): return permute_info, scores +# Sigmoid -> add bias -> TopK (selection) -> gather original sigmoid -> Renormalize +def routing_reference_minimax(expert_logits, routing_bias, top_k, padding): + assert routing_bias is not None, \ + "routing_reference_minimax requires routing_bias (MiniMax2 routing uses sigmoid + bias for expert selection)" + routing_logits = expert_logits.to(dtype=torch.float, device='cuda') + scores = F.sigmoid(routing_logits) + scores_with_bias = scores + routing_bias.to(torch.float) + + _, topk_idx = torch.topk(scores_with_bias, + k=top_k, + dim=-1, + largest=True, + sorted=False) + + # Gather the original (unbiased) sigmoid scores for the chosen experts + top_k_weights = scores.gather(1, topk_idx) + # Renormalize + top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + + 1e-20) + top_k_weights = top_k_weights.to(expert_logits.dtype) + + # Build full score matrix (same format as other routing references) + new_mask = torch.zeros_like(scores) + new_mask.scatter_(-1, topk_idx, 1) + result_scores = torch.zeros_like(scores) + for i in range(topk_idx.shape[0]): + for j in range(topk_idx.shape[1]): + result_scores[i, topk_idx[i, j]] = top_k_weights[i, j] + + permute_info = routing_reference(result_scores, top_k, padding) + return permute_info, result_scores + + +# Sigmoid -> TopK -> Renormalize (no bias) +def routing_reference_cohere_sigmoid(expert_logits, top_k, padding): + routing_logits = expert_logits.to(dtype=torch.float, device='cuda') + scores = torch.sigmoid(routing_logits) + + topk_weights, topk_idx = torch.topk(scores, + k=top_k, + dim=-1, + largest=True, + sorted=False) + + # Renormalize + topk_weights = topk_weights / (topk_weights.sum(dim=-1, keepdim=True) + + 1e-20) + topk_weights = topk_weights.to(expert_logits.dtype) + + # Build full score matrix (same format as other routing references) + result_scores = torch.zeros_like(scores) + for i in range(topk_idx.shape[0]): + for j in range(topk_idx.shape[1]): + result_scores[i, topk_idx[i, j]] = topk_weights[i, j] + + permute_info = routing_reference(result_scores, top_k, padding) + return permute_info, result_scores + + def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): input = input.to(torch.float) scale = scale.to(torch.float) @@ -880,10 +939,10 @@ def are_groups_valid(top_k_groups, n_groups): return True -@pytest.mark.skip( - reason= - "Deprecated: covered by tests/unittest/_torch/modules/moe/test_moe_backend.py and test_moe_module.py. Add new tests there." -) +# @pytest.mark.skip( +# reason= +# "Deprecated: covered by tests/unittest/_torch/modules/moe/test_moe_backend.py and test_moe_module.py. Add new tests there." +# ) @pytest.mark.skipif( getSMVersion() < 100 or getSMVersion() >= 110, reason="The kernel only supports Blackwell. Current SM is %d." % @@ -1018,10 +1077,10 @@ def run_moe_fp8_test(self, num_tokens: int, expert_info: Tuple[int, int, percent=0.925) -@pytest.mark.skip( - reason= - "Deprecated: covered by tests/unittest/_torch/modules/moe/test_moe_backend.py and test_moe_module.py. Add new tests there." -) +# @pytest.mark.skip( +# reason= +# "Deprecated: covered by tests/unittest/_torch/modules/moe/test_moe_backend.py and test_moe_module.py. Add new tests there." +# ) @pytest.mark.skipif( getSMVersion() < 100 or getSMVersion() >= 110, reason="The kernel only supports Blackwell. Current SM is %d." % @@ -1120,6 +1179,28 @@ class TestMoeFp4: "routing_method_type": RoutingMethodType.Renormalize }, id="RoutingRenormalize_large_experts"), + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.MiniMax2 + }, + id="RoutingMiniMax2"), + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.SigmoidRenorm + }, + id="RoutingSigmoidRenorm"), ], ) def test_autotune(self, num_tokens, hidden_size, intermediate_size, @@ -1215,6 +1296,28 @@ def test_autotune_fp8_fp4(self, num_tokens, hidden_size, intermediate_size, "routing_method_type": RoutingMethodType.Renormalize }, id="RoutingRenormalize_large_experts"), + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.MiniMax2 + }, + id="RoutingMiniMax2"), + pytest.param( + { + "num_experts": 128, + "top_k": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.SigmoidRenorm + }, + id="RoutingSigmoidRenorm"), ], ) @pytest.mark.parametrize("use_topk_as_input", [False, True], @@ -1383,10 +1486,12 @@ def run_moe_fp4_test(self, assert num_experts % n_groups == 0 assert top_k < (top_k_groups * num_experts / n_groups) - if routing_method_type == RoutingMethodType.DeepSeekV3: + if routing_method_type in (RoutingMethodType.DeepSeekV3, + RoutingMethodType.MiniMax2, + RoutingMethodType.SigmoidRenorm): expert_logits = torch.randn((num_tokens, num_experts), device='cuda').to(torch.float) - elif routing_method_type == RoutingMethodType.RenormalizeNaive or routing_method_type == RoutingMethodType.Renormalize: + else: expert_logits = torch.randn((num_tokens, num_experts), device='cuda').to(torch.bfloat16) @@ -1498,6 +1603,12 @@ def run_moe_fp4_test(self, elif routing_method_type == RoutingMethodType.RenormalizeNaive: permute_info, scores = routing_reference_renormalize_naive( expert_logits, top_k, padding) + elif routing_method_type == RoutingMethodType.MiniMax2: + permute_info, scores = routing_reference_minimax( + expert_logits, routing_bias, top_k, padding) + elif routing_method_type == RoutingMethodType.SigmoidRenorm: + permute_info, scores = routing_reference_cohere_sigmoid( + expert_logits, top_k, padding) args = moe_args(num_tokens, num_experts, @@ -1753,10 +1864,12 @@ def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int, "use_topk_as_input is tested only with routing_method_type=DeepSeekV3" ) - if routing_method_type == RoutingMethodType.DeepSeekV3: + if routing_method_type in (RoutingMethodType.DeepSeekV3, + RoutingMethodType.MiniMax2, + RoutingMethodType.SigmoidRenorm): expert_logits = torch.randn((num_tokens, num_experts), device='cuda').to(torch.float) - elif routing_method_type == RoutingMethodType.RenormalizeNaive or routing_method_type == RoutingMethodType.Renormalize: + else: expert_logits = torch.randn((num_tokens, num_experts), device='cuda').to(torch.bfloat16) @@ -1824,6 +1937,12 @@ def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int, elif routing_method_type == RoutingMethodType.RenormalizeNaive: permute_info, scores = routing_reference_renormalize_naive( expert_logits, top_k, padding) + elif routing_method_type == RoutingMethodType.MiniMax2: + permute_info, scores = routing_reference_minimax( + expert_logits, routing_bias, top_k, padding) + elif routing_method_type == RoutingMethodType.SigmoidRenorm: + permute_info, scores = routing_reference_cohere_sigmoid( + expert_logits, top_k, padding) args = moe_args(num_tokens, num_experts, hidden_size, intermediate_size, top_k, padding, hidden_states_fp8, None, @@ -1957,10 +2076,10 @@ def run_moe_fp8_fp4_test(self, num_tokens: int, hidden_size: int, percent=0.925) -@pytest.mark.skip( - reason= - "Deprecated: covered by tests/unittest/_torch/modules/moe/test_moe_backend.py and test_moe_module.py. Add new tests there." -) +# @pytest.mark.skip( +# reason= +# "Deprecated: covered by tests/unittest/_torch/modules/moe/test_moe_backend.py and test_moe_module.py. Add new tests there." +# ) @pytest.mark.skipif( getSMVersion() < 100 or getSMVersion() >= 110, reason="The kernel only supports Blackwell. Current SM is %d." % @@ -2186,10 +2305,10 @@ def test_moe_fp8_per_tensor_scale(num_tokens, hidden_size, intermediate_size, percent=0.925) -@pytest.mark.skip( - reason= - "Deprecated: covered by tests/unittest/_torch/modules/moe/test_moe_backend.py and test_moe_module.py. Add new tests there." -) +# @pytest.mark.skip( +# reason= +# "Deprecated: covered by tests/unittest/_torch/modules/moe/test_moe_backend.py and test_moe_module.py. Add new tests there." +# ) @pytest.mark.skipif( getSMVersion() != 100, reason="The kernel only supports Blackwell. Current SM is %d." %