diff --git a/csrc/fused_moe/trtllm_backend/routingDeepSeek/RoutingDeepSeekCommon.cuh b/csrc/fused_moe/trtllm_backend/routingDeepSeek/RoutingDeepSeekCommon.cuh new file mode 100644 index 0000000000..9e31a470f5 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingDeepSeek/RoutingDeepSeekCommon.cuh @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" + +namespace moe::dev::routing { +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +static constexpr int NumNemotronExperts = 512; +static constexpr int NumKimiK2Experts = 384; +static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = + std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); +static constexpr int NumTopGroupScores = 2; +static constexpr int MaxNumTopGroups = 4; +static constexpr int MaxNumGroups = 8; + +static constexpr int NumTop8Experts = 8; +static constexpr int NumTop22Experts = 22; +static constexpr int MaxSupportedTopExperts = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int constexpr getMaxNumExperts(int32_t numExperts) { + if (numExperts <= topk::MaxNumExpertsUnit) { + return topk::MaxNumExpertsUnit; + } else if (numExperts <= 
NumDeepseekExperts) { + return NumDeepseekExperts; + } else if (numExperts <= NumKimiK2Experts) { + return NumKimiK2Experts; + } else if (numExperts <= NumNemotronExperts) { + return NumNemotronExperts; + } else { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Helper macro: dispatch on topK tier for a given numExperts tier. +#define LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, forceFloatInput, numExperts) \ + if (data.mTopK <= NumTop8Experts) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, numExperts, NumTop8Experts); \ + } else if (data.mTopK <= NumTop22Experts) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, numExperts, NumTop22Experts); \ + } else { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag1, \ + forceFloatInput, numExperts, MaxSupportedTopExperts); \ + } + +#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput) \ + if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput, topk::MaxNumExpertsUnit); \ + } else if (data.mNumExperts <= NumDeepseekExperts) { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput, NumDeepseekExperts); \ + } else if (data.mNumExperts <= NumKimiK2Experts) { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, 
forceFloatInput, NumKimiK2Experts); \ + } else if (data.mNumExperts <= NumNemotronExperts) { \ + LAUNCH_DEEPSEEK_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, forceFloatInput, NumNemotronExperts); \ + } else { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchClusterKernel.cu b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchClusterKernel.cu new file mode 100644 index 0000000000..930b0ecbf3 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchClusterKernel.cu @@ -0,0 +1,61 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "RoutingDeepSeekCommon.cuh" + +namespace moe::dev::routing { +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) + __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesClusterKernel(KernelParams params) { + using OutputT = typename KernelParams::OutputT; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const clusterBlockRank = blockIdx.x; + + //@todo: try to move it into routingPermutation + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } + routingPermutation(params, nullptr, warpIdx, clusterBlockRank); +} +#else +__global__ void routingIndicesClusterKernel(KernelParams params) { + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchClusterKernel(Data& data, int numThreadsHist, void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchCoopKernel.cu b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchCoopKernel.cu new file mode 100644 index 0000000000..457f772fa1 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchCoopKernel.cu @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingDeepSeekCommon.cuh" + +namespace moe::dev::routing { +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __launch_bounds__(KernelParams::MaxNumExperts) + routingIndicesCoopKernel(KernelParams params) { + // number of experts is bounded by number of threads + int constexpr NumThreads = KernelParams::MaxNumExperts; + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; + // needed for the exclusive sum of token offsets + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage tempStorage; + // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. + static constexpr int MaxExpandedIdxPerThread = 64; + + // Initialize grid. + cg::grid_group grid = cg::this_grid(); + // Note: the following is more efficient than grid.block_index() because we don't use y and z. 
+ int32_t const gridBlockIdx = blockIdx.x; + int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; + int32_t const numBlocks = gridDim.x; + int32_t const numThreadsPerGrid = numBlocks * NumThreads; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + + auto expandedIdxSize = params.mNumTokens * params.mTopK; + + // pre-fill the counts with 0 + smemExpertCount[threadIdx.x] = 0; + __syncthreads(); + + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } + + // each thread has some number of "expanded indexes" assigned to it + // for each of these, we keep the associated expert and offset within expert in registers + int32_t expertIndexes[MaxExpandedIdxPerThread]; + int32_t expertOffsets[MaxExpandedIdxPerThread]; + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a + // time, and branch between a fast path without bound checks and a slow path with bound checks. + int constexpr IterStride = 4; + static_assert(MaxExpandedIdxPerThread % IterStride == 0); + + // Define a lambda to avoid code duplication in both branches. + auto loopBody = [&](int ii, int expandedIdx) { + int32_t expertIdx = params.mPtrTopKIds != nullptr ? params.mPtrTopKIds[expandedIdx] + : params.mPtrTopKPacked[expandedIdx].idx; + expertIndexes[ii] = expertIdx; + // check whether this expert is local to our GPU at all and ignore if not + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + expertOffsets[ii] = isLocalExpert ? 
atomicAdd(smemExpertCount + expertIdx, 1) : 0; + }; + +#pragma unroll + for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { + // Whether it's safe to do multiple iterations without bound checks. + bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; + if (takeFastPath) { +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + loopBody(ii, expandedIdx); + } + } else { + bool doBreak = false; +#pragma unroll + for (int32_t jj = 0; jj < IterStride; jj++) { + int const ii = ii0 + jj; + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) { + doBreak = true; + break; + } + loopBody(ii, expandedIdx); + } + if (doBreak) { + break; + } + } + } + + // Make histogram (token counts per expert) available to all threads in the block. + __syncthreads(); + + // + // Each thread now represents one expert + // + + // Add the local bin count to the common bin count and get a per-CTA offset. + int32_t const localExpertCount = smemExpertCount[threadIdx.x]; + + int32_t blockExpertOffset = 0; + if (threadIdx.x < params.mNumExperts) { + blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); + } + + // Sync to wait for completion of the histogram reduction. + grid.sync(); + + // Get total count for this expert. + int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; + + // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency. + + // Compute the runtime config for projections + // Whether or not an expert is local is taken into account when smemExpertCount is computed + // so we do not need to take it into account here. 
+ + int32_t numCta; + if constexpr (KernelParams::isPow2) { + numCta = divUpLog2(count, params.mPaddingLog2); + } else { + numCta = divUpTileN(count, params.mTileTokensDim); + } + + int32_t ctaOffset; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); + + for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) { + const int32_t localExpertIdx = + (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + } else { + mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); + } + + // get the padded offset associated with this expert + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset, params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset, params.mTileTokensDim); + } + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } + + // write out padded count + if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) { + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + + // write expert offsets to shared + smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; + + // make expert offsets available to all threads + __syncthreads(); + + // trigger the secondary kernel when using PDL + // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, + // 
mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens + // TODO: this is not sufficient to ensure visibility in the next kernel! + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } + +// each thread has the same "expanded indexes" assigned to it as above +// at this point, we know the final offsets of experts and the offsets within +// experts, which allows writing the final index values +#pragma unroll + for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) { + auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; + if (expandedIdx >= expandedIdxSize) { + break; + } + auto expertIdx = expertIndexes[ii]; + // check whether this expert is local to our GPU at all + auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + auto tokenIdx = expandedIdx / params.mTopK; + auto permutedIdx = + isLocalExpert ? 
int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + } + if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) { + params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; + } + if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } +} +#else +__global__ void routingIndicesCoopKernel(KernelParams params) { + assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchCoopKernel(Data& data, int numBlocksCoop, int numThreadsHist, void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchHistogramKernel.cu b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchHistogramKernel.cu new file mode 100644 index 0000000000..935457277a --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchHistogramKernel.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingDeepSeekCommon.cuh" + +namespace moe::dev::routing { +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchInitExpertCounts.cu b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchInitExpertCounts.cu new file mode 100644 index 0000000000..c2b2616327 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchInitExpertCounts.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingDeepSeekCommon.cuh" + +namespace moe::dev::routing { +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchInitExpertCounts(Data& data, int numThreadsHist, void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/false); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchMainKernel.cu b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchMainKernel.cu new file mode 100644 index 0000000000..082e543181 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchMainKernel.cu @@ -0,0 +1,263 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "RoutingDeepSeekCommon.cuh" + +namespace moe::dev::routing { +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void routingMainKernel(KernelParams params) { + // declare types + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + + // declare shared memory structure + // number of experts is bounded by number of threads + __shared__ float __attribute((aligned(128))) smemScoreSigmoid[KernelParams::MaxNumExperts]; + __shared__ float __attribute((aligned(128))) smemScoreBias[KernelParams::MaxNumExperts]; + // number of expert groups is bounded by number of warps + __shared__ float __attribute((aligned(128))) smemGroupScores[MaxNumGroups]; + + // needed for warp reduce + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + // for the final reduction of weight norm, only some lanes need to participate + int32_t laneIdx = threadIdx.x % WarpSize; + int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + // warps outside the range of expert groups do not participate + if constexpr (KernelParams::UseGroups) { + if (warpIdx >= params.mNumExpertGroups) { + return; + } + } + + // note that for invalid scores, we simply use a negative value: + // they work well even with the compacted format used in topK, and + // sigmoid / bias activated scores cannot be negative + static constexpr float invalidScoreFloat = float{-INFINITY}; + const OutputT invalidScore = OutputT{invalidScoreFloat}; + + // load bias already; each warp represents one expert group + auto threadExpert = threadIdx.x; + bool expertSelected = threadExpert < params.mNumExperts; + if constexpr (KernelParams::UseGroups) { + threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx; + expertSelected = laneIdx < params.mNumExpertsPerGroup; + } + auto scoreIdx = int64_t{blockIdx.x} * 
int64_t{params.mNumExperts} + threadExpert; + auto biasVal = expertSelected ? params.mPtrRoutingBias[threadExpert] : invalidScore; + + // initialize the mPtrExpertCounts + if (params.mPtrExpertCounts) { + int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; + int32_t globalThreadStride = gridDim.x * blockDim.x; + int32_t expertCountsNum = 2 * params.mNumExperts; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + } + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + // trigger the secondary kernel when using PDL, then wait on primary + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrScores != nullptr) { + // get our assigned thread score; each warp represents one expert group + float score = + expertSelected ? static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; + // get the sigmoid score + // note that for invalid values, we simply use a negative value: + // sigmoid scores are always strictly positive + auto scoreSigmoid = sigmoid_accurate(score); + // write the sigmoid score to shared for later use + if (expertSelected) { + smemScoreSigmoid[threadExpert] = scoreSigmoid; + } + // get the score with bias + // note that with invalid values, because sigmoid is < 1 and bias is -1, + // we must get a negative value, which is smaller than any valid value + auto scoreBias = float{scoreSigmoid + float{biasVal}}; + + if (expertSelected) { + smemScoreBias[threadExpert] = scoreBias; + } + + // registers for top group score reduction + float topExpGroupScores[NumTopGroupScores]; + [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; + float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups + int32_t topGroupIdx[MaxNumTopGroups]; + float expertScoreGroup[MaxNumTopGroups]; + int32_t expertIdxGroup[MaxNumTopGroups]; + float topScores[KernelParams::MaxNumTopExperts]; // bound of 
params.mTopK + int32_t topExperts[KernelParams::MaxNumTopExperts]; + + if constexpr (KernelParams::UseGroups) { + topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, + /* minValue */ invalidScoreFloat); + // get the final group score and write it to shared + if (cute::elect_one_sync()) { + auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; + smemGroupScores[warpIdx] = groupScore; + } + } + + // make group scores available to all warps + __syncthreads(); + + auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; + if constexpr (KernelParams::UseGroups) { // a single warp performs the selection of top groups, + // and goes on to select the final experts + if (warpIdx == 0) { + float groupScore = + laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat; + topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, + /* minValue */ invalidScoreFloat); + // final expert selection: get relevant indexes and scores from shared +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of params.mNumLimitedGroups + auto groupIdx = topGroupIdx[ii]; + expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; + // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. + // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, + // thus groupIdx <= params.mNumExpertGroups - 1 => + // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - + // params.mNumExpertsPerGroup + // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, + // so the access is safe here + expertScoreGroup[ii] = (ii < params.mNumLimitedGroups) && + (groupIdx < params.mNumExpertGroups) && expertSelected + ? 
smemScoreBias[expertIdxGroup[ii]] + : invalidScoreFloat; + } + + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } else if constexpr (KernelParams::MaxNumExperts > topk::MaxNumExpertsUnit) { + // without groups, each thread just takes `MaxNumTopGroups` experts + int constexpr NumExpertWarps = + (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; + int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; + __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; + __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; + if (warpIdx < NumExpertWarps) { + int offset = warpIdx * WarpSize * MaxNumTopGroups; +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = offset + expertIdx; + expertScoreGroup[ii] = offset + expertIdx < params.mNumExperts + ? 
smemScoreBias[offset + expertIdx] + : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + + if (laneIdx < params.mTopK) { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + topScores[laneIdx]; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + topExperts[laneIdx]; + } else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + invalidScoreFloat; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + MaxSupportedExpertCount - 1; + } + } + __syncthreads(); + if (warpIdx == 0) { + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; + float intermediateScore[NumInterTopKPerThread]; + int32_t intermediateExpert[NumInterTopKPerThread]; + for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { + int ii = i / WarpSize; + if (i < NumInterTopK) { + intermediateScore[ii] = smemInterTopScores[i]; + intermediateExpert[ii] = smemInterTopExperts[i]; + } else { + intermediateScore[ii] = invalidScoreFloat; + intermediateExpert[ii] = KernelParams::MaxNumExperts - 1; + } + } + topk::reduceTopK(warp, topScores, topExperts, intermediateScore, intermediateExpert, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } else { + if (warpIdx == 0) { + // without groups, each thread just takes `MaxNumTopGroups` experts +#pragma unroll + for (int ii = 0; ii < MaxNumTopGroups; ++ii) { + auto expertIdx = ii * WarpSize + laneIdx; + expertIdxGroup[ii] = expertIdx; + expertScoreGroup[ii] = + expertIdx < params.mNumExperts ? 
smemScoreBias[expertIdx] : invalidScoreFloat; + } + topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, + /* minValue */ invalidScoreFloat, params.mTopK); + } + } + + if (warpIdx == 0) { + // determine our lane's expert index and write to output + int32_t expertIdx = 0; +#pragma unroll + for (int ii = 0; ii < params.mTopK; ++ii) { // bound of params.mTopK + expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx; + } + // determine whether our expert is local to this GPU + auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; + auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; + auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); + auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; + + // write expert idx out already + auto idxTopK = blockIdx.x * params.mTopK + laneIdx; + if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) { + PackedScoreIdx packedScore{static_cast(finalScore), + static_cast(expertIdx)}; + params.mPtrTopKPacked[idxTopK] = packedScore; + } + + if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && + params.mPtrTopKIds == nullptr) { + params.mPtrTopKWeights[idxTopK] = finalScore; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git 
a/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchOffsetsKernel.cu b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchOffsetsKernel.cu new file mode 100644 index 0000000000..2799bd7542 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingDeepSeek/launchOffsetsKernel.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingDeepSeekCommon.cuh" + +namespace moe::dev::routing { +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, void* stream) { + LAUNCH_ROUTING_DEEPSEEK(data, + /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mNumExpertGroups > 1, /*forceFloatInput=*/true); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingRenormalize/RoutingRenormalizeCommon.cuh b/csrc/fused_moe/trtllm_backend/routingRenormalize/RoutingRenormalizeCommon.cuh new file mode 100644 index 0000000000..0308e44d11 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingRenormalize/RoutingRenormalizeCommon.cuh @@ -0,0 +1,131 @@ +/* + * Copyright (c) 
2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int NumExperts128Experts = 128; +static constexpr int NumExperts512Experts = 512; +static constexpr int MaxSupportedExperts = 2048; + +static constexpr int NumTop8Experts = 8; +static constexpr int NumTop16Experts = 16; +static constexpr int MaxSupportedTopExperts = 32; + +static constexpr int NumThreads = 1024; +static constexpr int NumWarps = NumThreads / WarpSize; + +static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; +static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; + +static constexpr int BlockKernelMaxNumTokens = 4; + +template +__forceinline__ __device__ void routingTopKExperts( + cg::thread_block_tile const& warp, DataType (&score)[VecSize], + int32_t (&idx)[VecSize], DataType (&warpTopKScore)[K], int32_t (&warpTopKExpertIdx)[K], + int32_t const laneIdx, int32_t const numExperts, int32_t topK, InputType const* ptrScores, + bool const normTopkProb, bool const applySoftmaxAfterTopK = true) { + DataType minScore = DataType{-INFINITY}; + + for (int i = 0; i < VecSize; i++) { + auto expertIdx = i * WarpSize + laneIdx; + auto 
newScore = expertIdx < numExperts ? static_cast(ptrScores[expertIdx]) : minScore; + score[i] = newScore; + idx[i] = expertIdx; + } + if constexpr (DoSoftmaxBeforeTopK) { + calcSoftmax(warp, score); + } + + // Get the top-k scores and their corresponding expert indices + topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); + + // Normalize the scores + if constexpr (DoSoftmaxBeforeTopK) { + float sum = float{1.f}; + if (normTopkProb) { + sum = static_cast(laneIdx < topK ? warpTopKScore[laneIdx] : 0); + sum = cg::reduce(warp, sum, cg::plus()); + } + if (laneIdx < topK) { + warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; + } + } else { + if (applySoftmaxAfterTopK) { + auto softmaxScore = + calcSoftmax(warp, laneIdx < topK ? warpTopKScore[laneIdx] : minScore, laneIdx, topK); + if (laneIdx < topK) { + warpTopKScore[laneIdx] = softmaxScore; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +int32_t constexpr getMaxNumExperts(int32_t numExperts) { + if (numExperts <= NumExperts128Experts) { + return NumExperts128Experts; + } else if (numExperts <= NumExperts512Experts) { + return NumExperts512Experts; + } else if (numExperts <= MaxSupportedExperts) { + return MaxSupportedExperts; + } else { + TLLM_LOG_ERROR("Unsupported numExperts"); + return 0; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Helper macro: dispatch on topK tier for a given numExperts tier. 
+#define LAUNCH_ROUTING_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, numExperts) \ + if (data.mTopK <= NumTop8Experts) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, numExperts, NumTop8Experts); \ + } else if (data.mTopK <= NumTop16Experts) { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, numExperts, NumTop16Experts); \ + } else { \ + LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, numExperts, MaxSupportedTopExperts); \ + } + +#define LAUNCH_ROUTING_RENORMALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1) \ + if (data.mNumExperts <= NumExperts128Experts) { \ + LAUNCH_ROUTING_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, NumExperts128Experts); \ + } else if (data.mNumExperts <= NumExperts512Experts) { \ + LAUNCH_ROUTING_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, NumExperts512Experts); \ + } else if (data.mNumExperts <= MaxSupportedExperts) { \ + LAUNCH_ROUTING_WITH_TOPK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ + extraFlag1, MaxSupportedExperts); \ + } else { \ + TLLM_LOG_ERROR("Unsupported numExperts"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingRenormalize/launchBlockKernel.cu b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchBlockKernel.cu new file mode 100644 index 0000000000..ac9b62c445 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchBlockKernel.cu @@ -0,0 +1,267 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. 
All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingRenormalizeCommon.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts + : 1024) + routingIndicesBlockKernel(KernelParams params) { + // types used in this kernel + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = std::conditional_t; + using TypePacked = PackedScoreIdx; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + // When MaxNumExperts > 1024, cap actual thread count at 1024 and let each thread handle + // multiple experts. This is needed because CUDA blocks support at most 1024 threads. + static constexpr int NumThreadsBlock = MaxNumExperts <= 1024 ? 
MaxNumExperts : 1024; + static constexpr int ExpertsPerThread = MaxNumExperts / NumThreadsBlock; + static_assert(MaxNumExperts % NumThreadsBlock == 0, + "MaxNumExperts must be a multiple of NumThreadsBlock"); + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + auto scoreOffset = warpIdx * params.mNumExperts; + bool validToken = warpIdx < params.mNumTokens; + + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + static constexpr int totalExpertCounts = BlockKernelMaxNumTokens * MaxNumExperts; + __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts]; + __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts]; + + using Scan = cub::BlockScan; + __shared__ typename Scan::TempStorage tempStorage; + + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + for (int i = threadIdx.x; i < totalExpertCounts; i += blockDim.x) { + smemOffset[i] = int8_t{-1}; + smemKIdx[i] = int8_t{-1}; + } + __syncthreads(); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif + + if (params.mPtrTopKIds != nullptr) { + if (validToken) { + if (laneIdx < params.mTopK) { + auto expertIdx = params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; + if (expertIdx != -1) { + int offset = warpIdx * MaxNumExperts + expertIdx; + smemKIdx[offset] = static_cast(laneIdx); + } else { + params.mPtrExpandedIdxToPermutedIdx[warpIdx * params.mTopK + laneIdx] = int32_t{-1}; + } + } + } + } else if (params.mPtrScores != nullptr) { + // in this case, each warp represents a token + BaseType score[VecSize]; + int32_t idx[VecSize]; + + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + + BaseType minScore = BaseType{-INFINITY}; + if (validToken) { + 
routingTopKExperts( + warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, + params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, + params.mApplySoftmaxAfterTopK); + + if (laneIdx < params.mTopK) { + int offset = warpIdx * MaxNumExperts + warpTopKExpertIdx[laneIdx]; + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = + OutputT{warpTopKScore[laneIdx]}; + } + } + } // end if (validToken) + } else if (params.mPtrTopKPacked != nullptr) { + if (validToken) { + if (laneIdx < params.mTopK) { + int offset = warpIdx * MaxNumExperts + + static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].idx); + smemKIdx[offset] = static_cast(laneIdx); + if (params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = + static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].score); + } + } + } + } + __syncthreads(); + + // Each thread handles ExpertsPerThread contiguous experts. + // Thread i handles experts [i * ExpertsPerThread, (i+1) * ExpertsPerThread). + // Contiguous assignment ensures prefix sum ordering is correct. + int accExpertCount[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + auto localExpIdx = expert - params.mLocalExpertsStartIdx; + auto isLocal = localExpIdx >= 0 && localExpIdx < params.mNumLocalExperts && + (localExpIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + // Get the count of each expert and the offset for each token + accExpertCount[e] = 0; + if (isLocal) { + int offset = expert; + for (int j = 0; j < BlockKernelMaxNumTokens; j++) { + if (smemKIdx[offset] >= 0) { + smemOffset[offset] = static_cast(accExpertCount[e]); + accExpertCount[e]++; + } + offset += MaxNumExperts; + } + } + } + __syncthreads(); + + // Get the number of CTAs and the offset for each CTA. 
+ // Use cub::BlockScan's array overload: each thread holds ExpertsPerThread items, + // and ExclusiveSum computes the prefix sum across all NumThreadsBlock * ExpertsPerThread + // items in thread order — exactly matching our contiguous expert assignment. + int32_t numCtaPerExpert[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + if constexpr (KernelParams::isPow2) { + numCtaPerExpert[e] = divUpLog2(accExpertCount[e], params.mPaddingLog2); + } else { + numCtaPerExpert[e] = divUpTileN(accExpertCount[e], params.mTileTokensDim); + } + } + int32_t ctaOffsetPerExpert[ExpertsPerThread]; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCtaPerExpert, ctaOffsetPerExpert, numNonExitingCtas); + __syncthreads(); // Required barrier before reusing TempStorage for the next BlockScan + + // Compute padded expert scan counts (same array-overload pattern) + int32_t tmpCountPerExpert[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + if constexpr (KernelParams::isPow2) { + tmpCountPerExpert[e] = divUpMulLog2(accExpertCount[e], params.mPaddingLog2); + } else { + tmpCountPerExpert[e] = divUpMulTileN(accExpertCount[e], params.mTileTokensDim); + } + } + int32_t expertScanCountsPerExpert[ExpertsPerThread]; + Scan(tempStorage).ExclusiveSum(tmpCountPerExpert, expertScanCountsPerExpert); + __syncthreads(); + + // Write CTA configs for each expert this thread handles +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + auto localExpIdx = expert - params.mLocalExpertsStartIdx; + auto isLocal = localExpIdx >= 0 && localExpIdx < params.mNumLocalExperts && + (localExpIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + if (isLocal) { + for (int cta = 0; cta < numCtaPerExpert[e]; ++cta) { + int32_t const mappedLocalIdx = + (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + 
params.mPtrCtaIdxXyToBatchIdx[ctaOffsetPerExpert[e] + cta] = mappedLocalIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffsetPerExpert[e] + cta + 1, params.mPaddingLog2); + mnLimit2 = + mulLog2(ctaOffsetPerExpert[e], params.mPaddingLog2) + accExpertCount[e]; + } else { + mnLimit1 = mulTileN(ctaOffsetPerExpert[e] + cta + 1, params.mTileTokensDim); + mnLimit2 = + mulTileN(ctaOffsetPerExpert[e], params.mTileTokensDim) + accExpertCount[e]; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffsetPerExpert[e] + cta] = min(mnLimit1, mnLimit2); + } + } + } + + // at this point, we can write out padded count + if (threadIdx.x == 0) { + int32_t permutedIdxSize; + if constexpr (KernelParams::isPow2) { + permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); + } else { + permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); + } + params.mPtrPermutedIdxSize[0] = permutedIdxSize; + params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // we can trigger the next kernel at this point + if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + + for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) { +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + int offset = tokenIdx * MaxNumExperts + expert; + if (smemKIdx[offset] >= 0) { + auto localExpIdx = expert - params.mLocalExpertsStartIdx; + auto isLocal = localExpIdx >= 0 && localExpIdx < params.mNumLocalExperts && + (localExpIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; + + int const expandedIdx = tokenIdx * params.mTopK + smemKIdx[offset]; + int const offsetWithinExpert = static_cast(smemOffset[offset]); + int const offsetForExpert = expertScanCountsPerExpert[e]; + int const permutedIdx = isLocal ? 
offsetForExpert + offsetWithinExpert : int32_t{-1}; + + if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { + params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; + } + if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocal) { + params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; + } + if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocal) { + params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream) { + LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesBlockKernel, 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingRenormalize/launchClusterKernel.cu b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchClusterKernel.cu new file mode 100644 index 0000000000..28cfd29584 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchClusterKernel.cu @@ -0,0 +1,112 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "RoutingRenormalizeCommon.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams params) { + // number of tokens/expanded idx is bounded by total number of warps + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + + using BaseType = std::conditional_t; + using TypePacked = PackedScoreIdx; + + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + __shared__ TypePacked + __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * KernelParams::MaxNumTopExperts]; + + uint32_t const clusterBlockRank = blockIdx.x; + + int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); + int32_t const laneIdx = cutlass::arch::LaneId(); + + auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; + auto scoreOffset = warpTokenIdx * params.mNumExperts; + bool validToken = warpTokenIdx < params.mNumTokens; + + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + + // then wait on primary grid + if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } + + if (params.mPtrScores != nullptr) { + // in this case, each warp represents a token + BaseType score[VecSize]; + int32_t idx[VecSize]; + + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + + BaseType minScore = BaseType{-INFINITY}; + if (validToken) { + routingTopKExperts( + warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, + params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, + params.mApplySoftmaxAfterTopK); + + if (laneIdx < params.mTopK) { + 
smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] = + TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; + } + } // end if (validToken) + } + + // make packed scores available to all threads in cluster + __cluster_barrier_arrive(); + __cluster_barrier_wait(); + + if (params.mPtrScores != nullptr) { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } else { + routingPermutation(params, smemPackedScoreIdx, warpIdx, + clusterBlockRank); + } +} +#else +__global__ void __launch_bounds__(NumThreads) + routingIndicesClusterKernel(KernelParams /* params */) { + assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); +} +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchClusterKernel(Data const& data, void* stream) { + LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, + NumThreads, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingRenormalize/launchHistogramKernel.cu b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchHistogramKernel.cu new file mode 100644 index 0000000000..2046e4a1e0 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchHistogramKernel.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingRenormalizeCommon.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, + void* stream) { + LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesHistogramKernel, numBlocksHistogram, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingRenormalize/launchHistogramScoresKernel.cu b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchHistogramScoresKernel.cu new file mode 100644 index 0000000000..4da8e545a6 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchHistogramScoresKernel.cu @@ -0,0 +1,104 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "RoutingRenormalizeCommon.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// this kernel is needed in case we have scores as input for the histogram kernel +template +__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts + : 1024) + routingIndicesHistogramScoresKernel(KernelParams params) { + using OutputT = typename KernelParams::OutputT; + using InputT = typename KernelParams::InputT; + using BaseType = std::conditional_t; + // Cap actual thread count at 1024 when MaxNumExperts > 1024. + static constexpr int NumThreadsBlock = + KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts : 1024; + + // VecSize stays based on MaxNumExperts — each warp still processes all experts for one token. + static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; + + int32_t const laneIdx = cutlass::arch::LaneId(); + int32_t const warpIdx = threadIdx.x / WarpSize; + // Use NumThreadsBlock (actual thread count) for grid-stride warp/thread addressing + int32_t const globalWarpIdx = blockIdx.x * NumThreadsBlock / WarpSize + warpIdx; + int32_t const globalWarpStride = gridDim.x * NumThreadsBlock / WarpSize; + BaseType minScore = BaseType{-INFINITY}; + auto block = cg::this_thread_block(); + auto warp = cg::tiled_partition(block); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Wait on primary grid. 
+ if constexpr (KernelParams::UsePdl) { + cudaGridDependencySynchronize(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + + // initialize the mPtrExpertCounts — use NumThreadsBlock for grid-stride + int32_t expertCountsNum = 2 * params.mNumExperts; + int32_t globalThreadIdx = blockIdx.x * NumThreadsBlock + threadIdx.x; + int32_t globalThreadStride = gridDim.x * NumThreadsBlock; + initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); + + // in this case, each warp represents a token, and we use a grid-stride loop + // over all warps/tokens + BaseType allScores[VecSize]; + int32_t allExpertIdx[VecSize]; + BaseType warpTopKScore[KernelParams::MaxNumTopExperts]; + int32_t warpTopKExpertIdx[KernelParams::MaxNumTopExperts]; + for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) { + auto scoreOffset = tokenIdx * params.mNumExperts; + + routingTopKExperts( + warp, allScores, allExpertIdx, warpTopKScore, warpTopKExpertIdx, laneIdx, + params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, + params.mApplySoftmaxAfterTopK); + + if (laneIdx < params.mTopK) { + PackedScoreIdx packedScore{static_cast(warpTopKScore[laneIdx]), + static_cast(warpTopKExpertIdx[laneIdx])}; + params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; + } + } + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + // Trigger secondary kernel AFTER writing all packed scores, so the next kernel + // (routingIndicesHistogramKernel) sees the completed mPtrTopKPacked writes. 
+ if constexpr (KernelParams::UsePdl) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, uint32_t numThreadsHist, + void* stream) { + LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingRenormalize/launchInitExpertCounts.cu b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchInitExpertCounts.cu new file mode 100644 index 0000000000..5344e85365 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchInitExpertCounts.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "RoutingRenormalizeCommon.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream) { + LAUNCH_ROUTING_RENORMALIZE(data, false, routingInitExpertCounts, + (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing diff --git a/csrc/fused_moe/trtllm_backend/routingRenormalize/launchOffsetsKernel.cu b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchOffsetsKernel.cu new file mode 100644 index 0000000000..19fbb844c8 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/routingRenormalize/launchOffsetsKernel.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "RoutingRenormalizeCommon.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, + void* stream) { + LAUNCH_ROUTING_RENORMALIZE(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, + numThreadsHist, + /*smemSize=*/0, // No dynamic smem + stream, data.mDoSoftmaxBeforeTopK); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing diff --git a/csrc/trtllm_fused_moe_dev_kernel.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu similarity index 99% rename from csrc/trtllm_fused_moe_dev_kernel.cu rename to csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu index 63e8aef5a4..50cabaeacc 100644 --- a/csrc/trtllm_fused_moe_dev_kernel.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu @@ -377,7 +377,7 @@ void run(Data const& data, void* stream) { DEEP_SEEK_ACTIVATION_NUM_THREADS_PER_CTA, 0, stream); } else { int const numThreads = 256; - const dim3 grid(data.innerDim / 128, data.topK, data.numTokens); + const dim3 grid(data.innerDim / 128, data.topK, std::min(8192, data.numTokens)); LAUNCH_ACTIVATION(data, activationKernel, 1, grid, numThreads, 0, stream); } @@ -540,7 +540,7 @@ void run(Data const& data, void* stream) { constexpr int VecSize = 4; int const numThreads = 128; int const numBlocksX = (data.hiddenDimSf / VecSize - 1 + numThreads) / numThreads; - int const numBlocksY = data.numTokens; + int const numBlocksY = std::min(8192, data.numTokens); dim3 numBlocks(numBlocksX, numBlocksY); #define CONVERT_FP4_SF_LAUNCH(LayoutSrc, LayoutDst) \ if (data.sfLayoutSrc == tg::SfLayout::LayoutSrc && \ @@ -615,7 +615,7 @@ __global__ void permuteKernel(KernelParams params) { void run(Data 
const& data, void* stream) { int const numThreads = 256; int const numBlocksX = (data.hiddenDim - 1 + numThreads) / numThreads; - int const numBlocksY = data.numTokens; + int const numBlocksY = std::min(8192, data.numTokens); dim3 numBlocks(numBlocksX, numBlocksY); LAUNCH(data, permuteKernel, numBlocks, numThreads, 0, stream); diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu new file mode 100644 index 0000000000..4ad38feb26 --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "routingDeepSeek/RoutingDeepSeekCommon.cuh" + +namespace moe::dev::routing { +namespace routingDeepSeek { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Forward declarations for split-compiled launch wrappers. 
+void launchMainKernel(Data& data, int numBlocks, int numThreadsMain, void* stream); +void launchInitExpertCounts(Data& data, int numThreadsHist, void* stream); +void launchClusterKernel(Data& data, int numThreadsHist, void* stream); +void launchCoopKernel(Data& data, int numBlocksCoop, int numThreadsHist, void* stream); +void launchHistogramKernel(Data& data, int numBlocksHistogram, int numThreadsHist, void* stream); +void launchOffsetsKernel(Data& data, int numBlocksOffsets, int numThreadsHist, void* stream); + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void run(Data& data, void* stream) { + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + if (data.mPtrTopKIds != nullptr) { + FLASHINFER_CHECK(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "DeepSeek routing."); + } + if (data.mPtrExpandedIdxToPermutedIdx != nullptr || + data.mPtrPermutedIdxToExpandedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr) + FLASHINFER_CHECK( + (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize, + "If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required"); + FLASHINFER_CHECK(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); + int const numBlocks = data.mNumTokens; + int const numThreadsHist = getMaxNumExperts(data.mNumExperts); + + bool const useSingleCluster = data.mNumTokens <= 1024; + if (!useSingleCluster) { + FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, + "When #tokens is large, `mPtrExpertCounts` is a required input."); + } else { + data.mPtrExpertCounts = nullptr; + } + + static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + int const numBlocksCoop = smCount - 8; + + int const maxTokensCoop = (numBlocksCoop * 
numThreadsHist * 64) / data.mTopK; + if (data.mPtrTopKIds == nullptr) { + FLASHINFER_CHECK(data.mNumExperts >= MaxSupportedTopExperts, + "Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, + data.mNumExperts); + FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExpertCount, + "Routing kernel expects #experts %d <= %d", data.mNumExperts, + MaxSupportedExpertCount); + FLASHINFER_CHECK(data.mTopK <= MaxSupportedTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts, + data.mTopK); + + if (data.mNumExpertGroups > 1) { + FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups, + "Routing kernel expects #expert groups %d to be <= max groups %d", + data.mNumExpertGroups, MaxNumGroups); + FLASHINFER_CHECK(data.mNumExperts % data.mNumExpertGroups == 0, + "Routing kernel expects #experts %d to be a multiple of #expert groups %d", + data.mNumExperts, data.mNumExpertGroups); + FLASHINFER_CHECK(data.mNumExperts / data.mNumExpertGroups <= WarpSize, + "Routing kernel expects #experts per group <= warp size, got %d", + data.mNumExperts / data.mNumExpertGroups); + FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, + "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, + data.mNumLimitedGroups); + FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, + "Routing kernel expects top groups %d to be limited by #expert groups %d", + data.mNumLimitedGroups, data.mNumExpertGroups); + FLASHINFER_CHECK(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts %d to be a multiple of 4.", + data.mNumExperts); + } + + int const numThreadsMain = + max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); + launchMainKernel(data, numBlocks, numThreadsMain, stream); + } else { + launchInitExpertCounts(data, numThreadsHist, stream); + } + + if (data.mPtrPermutedIdxSize != nullptr) { + if (useSingleCluster) { + launchClusterKernel(data, numThreadsHist, stream); + } else if 
(data.mNumTokens <= maxTokensCoop) { + launchCoopKernel(data, numBlocksCoop, numThreadsHist, stream); + } else { + const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; + const int32_t histogramEltsPerBlock = 8 * numThreadsHist; + const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + + const int32_t maxNumBlocks = 1024; + + int const numBlocksHistogram = std::min( + (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets = + std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); + launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingDeepSeek +} // namespace moe::dev::routing diff --git a/csrc/trtllm_fused_moe_routing_llama4.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu similarity index 75% rename from csrc/trtllm_fused_moe_routing_llama4.cu rename to csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu index 31674e0a8e..41507a56af 100644 --- a/csrc/trtllm_fused_moe_routing_llama4.cu +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu @@ -25,13 +25,10 @@ namespace routingLlama4 { static constexpr int NumThreads = 1024; static constexpr int NumWarps = NumThreads / WarpSize; static constexpr int MaxNumTopExperts = 1; -static constexpr int NumExpertsLimit = 128; +static constexpr int MaxSupportedExperts = 128; static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; static constexpr int WarpKernelSmemStride = 33; -// with further optimization to `routingIndicesWarpKernel`, this limit may -// increase. 
For now, it is a good cut-off point for when the block-wise -// operations are more efficient end-to-end. static constexpr int WarpKernelMaxNumTokens = 4; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -47,11 +44,9 @@ __forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile; - // Non-vectorized loading: directly access ptrScores with expertIdx for (int i = 0; i < VecSize; ++i) { auto expertIdx = i * WarpSize + laneIdx; auto newScore = expertIdx < numExperts ? ptrScores[expertIdx] : minScore; - // note: use `>=` s.t. highest index always wins, just like in `reduceTopK` if (newScore > maxScore) { maxScore = newScore; maxExpertIdx = expertIdx; @@ -65,68 +60,50 @@ __forceinline__ __device__ void routingTopKExperts(cg::thread_block_tile __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParams params) { - // types used in this kernel using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; using TypePacked = PackedScoreIdx; - // use the default cub warp-scan, with shfl using Scan = cub::WarpScan; __shared__ typename Scan::TempStorage tempStorage; - // each thread encodes 4 experts in one `int32_t`. The assumption is that - // we don't have more than 127 tokens, but `WarpKernelMaxNumTokens` must be - // smaller than that because other approaches will be more efficient for - // 127 tokens. static constexpr int ExpertsPerThread = sizeof(int32_t); static_assert(WarpKernelMaxNumTokens <= 127); - // this is a full table of which token is routed to which expert. - // the assumption here is that there are no more than 128 experts. - // we use a stride of 33 instead of 32 to avoid shared memory bank conflicts. 
__shared__ int32_t __attribute(( aligned(128))) smemExpertTokenCountFull[WarpKernelMaxNumTokens][WarpKernelSmemStride]; static_assert(WarpKernelSmemStride == WarpSize + 1); static_assert(KernelParams::MaxNumExperts / sizeof(int32_t) <= WarpSize); - // values needed for the top-1 reduction, if required InputT minScore = InputT{-INFINITY}; auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); #pragma unroll for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx) { - // reset full shared memory field to 0 smemExpertTokenCountFull[tokenIdx][threadIdx.x] = 0; } __syncwarp(); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // then wait on primary grid if constexpr (KernelParams::UsePdl) { cudaGridDependencySynchronize(); } #endif if (params.mPtrScores != nullptr && params.mPtrTopKIds == nullptr) { - // if we use `mPtrScores` as input, we need to perform the top-1 reduction - // for each token, we load the scores then use `reduceTopK` for this. - // each thread works on 4 experts, so a local reduction is done before for (int tokenIdx = 0; tokenIdx < params.mNumTokens; ++tokenIdx) { auto scoreOffset = tokenIdx * params.mNumExperts; int32_t warpMaxExpertIdx[MaxNumTopExperts]; InputT warpMaxScore[MaxNumTopExperts]; - // Use routingTopKExperts function instead of inline logic routingTopKExperts(warp, warpMaxScore, warpMaxExpertIdx, threadIdx.x, params.mNumExperts, params.mPtrScores + scoreOffset); if (cute::elect_one_sync()) { - // one thread updates the count linking token to chosen expert auto expertTokenCount = 0; setBits(expertTokenCount, 1, warpMaxExpertIdx[0] % ExpertsPerThread); smemExpertTokenCountFull[tokenIdx][warpMaxExpertIdx[0] / ExpertsPerThread] = expertTokenCount; - // we also compute the final score here and write it out if required auto finalScore = OutputT{sigmoid_accurate(float{warpMaxScore[0]})}; if (params.mPtrTopKWeights != nullptr) { params.mPtrTopKWeights[tokenIdx] = finalScore; @@ -134,10 +111,6 @@ 
__global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } } } else { - // if we do not have `mPtrScores` as input, we expect that `params.mPtrTopKPacked` or - // `params.mPtrTopKIds` and `params.mPtrTopKWeights` contains the top-1 packed score and index - // already. Each thread represents a token here, and we extract the relevant score The - // assumption is that the #tokens is limited by warp-size static_assert(WarpKernelMaxNumTokens <= WarpSize); TypePacked scoreIdx = TypePacked{}; if (params.mPtrTopKIds != nullptr) { @@ -150,7 +123,6 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam scoreIdx = TypePacked{static_cast(params.mPtrTopKPacked[threadIdx.x].score), static_cast(params.mPtrTopKPacked[threadIdx.x].idx)}; if (params.mPtrTopKWeights != nullptr) { - // we also compute the final score here and write it out if required auto finalScore = OutputT{sigmoid_accurate(float{scoreIdx.score})}; params.mPtrTopKWeights[threadIdx.x] = finalScore; } @@ -164,27 +136,19 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } } - // make the full table available to all threads __syncwarp(); - // at this point, each thread keeps a count of its 4 assigned experts in - // `expertCount`, as well as the offsets for all tokens w.r.t. these 4 experts - // in `expertOffset`. int32_t expertCount = 0; int32_t expertOffset[WarpKernelMaxNumTokens + 1]; #pragma unroll for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens + 1; ++tokenIdx) { if (tokenIdx > params.mNumTokens) break; - // simple reduction for `expertCount`, and scan for `expertOffset` auto expertTokenCount = tokenIdx < params.mNumTokens ? 
smemExpertTokenCountFull[tokenIdx][threadIdx.x] : 0; expertOffset[tokenIdx] = expertCount; expertCount += expertTokenCount; } - // at this point, we are ready for the scan across all experts to get the - // thread-wise offsets across experts - // first, we need to reduce across our 4 experts into `numCta` int32_t numCta = 0; #pragma unroll for (int ii = 0; ii < ExpertsPerThread; ++ii) { @@ -197,13 +161,10 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } numCta += num; } - // second, we perform the exclusive sum across the warp int32_t ctaOffset; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - // finally, we perform a scan across our local experts, starting with the - // warp-wide scan result (`ctaOffset`) auto ctaOffsetExp = ctaOffset; #pragma unroll for (int ii = 0; ii < ExpertsPerThread; ++ii) { @@ -215,8 +176,6 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam finalNumCta = divUpTileN(count, params.mTileTokensDim); } auto expertIdx = threadIdx.x * ExpertsPerThread + ii; - // during the scan for expert offsets, we can already write out - // both `mPtrCtaIdxXyToBatchIdx` and `mPtrCtaIdxXyToMnLimit` for (int cta = 0; cta < finalNumCta; ++cta) { params.mPtrCtaIdxXyToBatchIdx[ctaOffsetExp + cta] = expertIdx; int32_t mnLimit1; @@ -233,7 +192,6 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam ctaOffsetExp += finalNumCta; } - // at this point, we can write out padded count from the warp-aggregate if (cute::elect_one_sync()) { int32_t permutedIdxSize; if constexpr (KernelParams::isPow2) { @@ -246,17 +204,11 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // we can trigger the next kernel at this point if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif - // at this point, all values for 
offsets are ready, except the final offsets - // within the padded index (`permutedIdx`) - // for this, we perform a scan similar to the one directly after the warp-scan: - // here, we keep the local offset for each of the thread's experts in a field - // of registers auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; int32_t finalExpertOffset[ExpertsPerThread]; if constexpr (KernelParams::isPow2) { @@ -277,48 +229,36 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam #pragma unroll for (int tokenIdx = 0; tokenIdx < WarpKernelMaxNumTokens; ++tokenIdx) { - // at this point, we can calculate the final index: - // we simply loop over all tokens, and all experts assigned to this thread. - // For each pair, we determine whether that token was routed to that expert - // based on whether the offset for that token changed. - // we can then easily compute the final `expertIdx` and `permutedIdx` relative - // to this token and expert, and write them out. if (tokenIdx >= params.mNumTokens) break; #pragma unroll for (int ii = 0; ii < ExpertsPerThread; ++ii) { - // determine whether the offset for this expert and token changes auto localOffsetToken = getBits(expertOffset[tokenIdx], ii); auto isTokenRouted = getBits(expertOffset[tokenIdx + 1], ii) > localOffsetToken; - // the expert index of this expert auto expertIdx = threadIdx.x * ExpertsPerThread + ii; auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - // the permuted index: we add the local offset relative to this expert and token - // to the global offset from the scan for this expert + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; auto permutedIdx = isLocalExpert ? 
finalExpertOffset[ii] + localOffsetToken : int32_t{-1}; - // write out `mPtrExpandedIdxToPermutedIdx` if required if (params.mPtrExpandedIdxToPermutedIdx != nullptr && isTokenRouted) { params.mPtrExpandedIdxToPermutedIdx[tokenIdx] = permutedIdx; } - // write out `mPtrPermutedIdxToExpandedIdx` if required - if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert && isTokenRouted) { + if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) { params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = tokenIdx; } - // write out `mPtrPermutedIdxToTokenIdx` if required if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert && isTokenRouted) { params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; } } } } + //////////////////////////////////////////////////////////////////////////////////////////////////// + template #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) routingIndicesClusterKernel(KernelParams params) { - // number of tokens/expanded idx is bounded by total number of warps using OutputT = typename KernelParams::OutputT; using InputT = typename KernelParams::InputT; using TypePacked = PackedScoreIdx; @@ -328,7 +268,6 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); int32_t const laneIdx = cutlass::arch::LaneId(); - // TODO(mjoux): expand to more tokens (possibly) auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; auto scoreOffset = warpTokenIdx * params.mNumExperts; bool validToken = warpTokenIdx < params.mNumTokens; @@ -337,7 +276,6 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - // then wait on primary grid if constexpr (KernelParams::UsePdl) { cudaGridDependencySynchronize(); } @@ -349,9 +287,6 @@ 
__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu smemPackedScoreIdx[warpIdx] = packedScore; } } else if (params.mPtrScores != nullptr) { - // in this case, each warp represents a token - // we then exchange all token max scores, s.t. afterwards, each thread - // represents a token InputT warpMaxScore[MaxNumTopExperts]; int32_t warpMaxExpertIdx[MaxNumTopExperts]; @@ -371,7 +306,6 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu } } - // make packed scores available to all threads in cluster __cluster_barrier_arrive(); __cluster_barrier_wait(); @@ -393,7 +327,6 @@ __global__ void routingIndicesClusterKernel(KernelParams params) { //////////////////////////////////////////////////////////////////////////////////////////////////// -// this kernel is needed in case we have scores as input for the histogram kernel template __global__ void __launch_bounds__(KernelParams::MaxNumExperts) routingIndicesHistogramScoresKernel(KernelParams params) { @@ -401,7 +334,6 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) using InputT = typename KernelParams::InputT; using TypePacked = PackedScoreIdx; static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - // we assume that #experts is a multiple of 4, so VecSize must be 4. 
static_assert(VecSize == 4); int32_t const laneIdx = cutlass::arch::LaneId(); @@ -412,22 +344,18 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) auto block = cg::this_thread_block(); auto warp = cg::tiled_partition(block); - // initialize the mPtrExpertCounts int32_t expertCountsNum = 2 * params.mNumExperts; int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid and trigger secondary kernel. if constexpr (KernelParams::UsePdl) { cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); } #endif - // in this case, each warp represents a token, and we use a grid-stride loop - // over all warps/tokens for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) { auto scoreOffset = tokenIdx * params.mNumExperts; int32_t warpMaxExpertIdx[MaxNumTopExperts]; @@ -458,6 +386,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } //////////////////////////////////////////////////////////////////////////////////////////////////// + int constexpr getMaxNumExperts(int32_t numExperts) { if (numExperts <= topk::MaxNumExpertsUnit) { return topk::MaxNumExpertsUnit; @@ -469,7 +398,7 @@ int constexpr getMaxNumExperts(int32_t numExperts) { //////////////////////////////////////////////////////////////////////////////////////////////////// -void runImpl(Data const& data, void* stream) { +void run(Data const& data, void* stream) { FLASHINFER_CHECK( data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, "Routing kernel requires at least one input parameter"); @@ -485,9 +414,9 @@ void runImpl(Data const& data, void* stream) { FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, "Routing kernel expects topK 
experts <= %d, got %d", MaxNumTopExperts, data.mTopK); - FLASHINFER_CHECK(data.mNumExperts <= NumExpertsLimit, + FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExperts, "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, - NumExpertsLimit); + MaxSupportedExperts); FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); @@ -510,21 +439,16 @@ void runImpl(Data const& data, void* stream) { if (useSingleWarp) { LAUNCH_ROUTING_LLAMA4(data, /*coopLaunch=*/false, routingIndicesWarpKernel, 1, WarpSize, - /*smemSize=*/0, // No dynamic smem - stream); + /*smemSize=*/0, stream); } else if (useSingleCluster) { LAUNCH_ROUTING_LLAMA4(data, /*coopLaunch=*/false, routingIndicesClusterKernel, NumBlocksPerCluster, NumThreads, - /*smemSize=*/0, // No dynamic smem - stream); + /*smemSize=*/0, stream); } else { const uint32_t expandedIdxSize = data.mNumTokens * data.mTopK; - const uint32_t histogramEltsPerBlock = 8 * numThreadsHist; const uint32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; - - // Limit grid size (all kernels use a grid-stride loop). const uint32_t maxNumBlocks = 1024; int const numBlocksHistogram = std::min( @@ -536,45 +460,23 @@ void runImpl(Data const& data, void* stream) { LAUNCH_ROUTING_LLAMA4(data, /*coopLaunch=*/false, routingIndicesHistogramScoresKernel, maxNumBlocks, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream); + /*smemSize=*/0, stream); } else { - // Reset the global histograms. 
LAUNCH_ROUTING_LLAMA4(data, false, routingInitExpertCounts, (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream); + /*smemSize=*/0, stream); } LAUNCH_ROUTING_LLAMA4(data, /*coopLaunch=*/false, routingIndicesHistogramKernel, numBlocksHistogram, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream); + /*smemSize=*/0, stream); LAUNCH_ROUTING_LLAMA4(data, /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream); + /*smemSize=*/0, stream); } } -void run(Data const& data, void* stream) { - FLASHINFER_CHECK( - data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, - "Routing kernel requires at least one input parameter"); - FLASHINFER_CHECK( - data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && - data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, - "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"); - FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, - "Routing kernel expects topK experts <= ", MaxNumTopExperts, ", got ", - data.mTopK); - FLASHINFER_CHECK(data.mPaddingLog2 < 8, "Routing kernel expects padding log2 < 8, got ", - data.mPaddingLog2); - - runImpl(data, stream); -} - //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace routingLlama4 diff --git a/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_renormalize.cu b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_renormalize.cu new file mode 100644 index 0000000000..244d84473b --- /dev/null +++ b/csrc/fused_moe/trtllm_backend/trtllm_fused_moe_routing_renormalize.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "routingRenormalize/RoutingRenormalizeCommon.cuh" + +namespace moe::dev::routing { +namespace routingRenormalize { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Forward declarations of per-kernel launch wrappers (defined in routingRenormalize/*.cu). +void launchBlockKernel(Data const& data, uint32_t numThreadsHist, void* stream); +void launchClusterKernel(Data const& data, void* stream); +void launchHistogramScoresKernel(Data const& data, uint32_t maxNumBlocks, uint32_t numThreadsHist, + void* stream); +void launchInitExpertCounts(Data const& data, uint32_t numThreadsHist, void* stream); +void launchHistogramKernel(Data const& data, int numBlocksHistogram, uint32_t numThreadsHist, + void* stream); +void launchOffsetsKernel(Data const& data, int numBlocksOffsets, uint32_t numThreadsHist, + void* stream); + +//////////////////////////////////////////////////////////////////////////////////////////////////// +void run(Data const& data, void* stream) { + FLASHINFER_CHECK( + data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, + "Routing kernel requires at least one input parameter"); + if (data.mPtrTopKIds != nullptr) { + FLASHINFER_CHECK(data.mPtrTopKWeights != nullptr, + "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " + "Renormalize routing."); + } + FLASHINFER_CHECK( + 
data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && + data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr, + "Renormalize routing kernel expects permuted idx and grouped Gemm launch config buffers"); + FLASHINFER_CHECK(data.mTopK <= MaxSupportedTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts, + data.mTopK); + FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExperts, + "Routing kernel expects #experts %d to be no more than %d", data.mNumExperts, + MaxSupportedExperts); + FLASHINFER_CHECK(data.mNumExperts % 4 == 0, + "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); + + bool const useSingleBlock = + data.mNumTokens <= BlockKernelMaxNumTokens && data.mPtrTopKPacked == nullptr; + + bool const useSingleCluster = + data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) + ? MaxNumTokensSingleClusterScores + : MaxNumTokensSingleCluster); + + if (!useSingleCluster && !useSingleBlock) { + FLASHINFER_CHECK( + (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr), + "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."); + FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, + "When #tokens is large, `mPtrExpertCounts` is a required input."); + } + uint32_t const numThreadsHist = min(1024, getMaxNumExperts(data.mNumExperts)); + if (useSingleBlock) { + launchBlockKernel(data, numThreadsHist, stream); + } else if (useSingleCluster) { + launchClusterKernel(data, stream); + } else { + uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; + uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; + uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; + + uint32_t const maxNumBlocks = 1024; + + int const numBlocksHistogram = std::min( + (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); + int const numBlocksOffsets = + 
std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); + + if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) { + launchHistogramScoresKernel(data, maxNumBlocks, numThreadsHist, stream); + } else { + launchInitExpertCounts(data, numThreadsHist, stream); + } + launchHistogramKernel(data, numBlocksHistogram, numThreadsHist, stream); + launchOffsetsKernel(data, numBlocksOffsets, numThreadsHist, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace routingRenormalize +} // namespace moe::dev::routing diff --git a/csrc/moe_utils_binding.cu b/csrc/moe_utils_binding.cu index 8cfd00f3eb..0a6f4bb65b 100644 --- a/csrc/moe_utils_binding.cu +++ b/csrc/moe_utils_binding.cu @@ -291,8 +291,6 @@ void moe_sort( // Configure dtypes routingData.mDtypeExpW = batchedGemm::trtllm::gen::Dtype::Bfloat16; - routingData.mDtypeBias = batchedGemm::trtllm::gen::Dtype::Bfloat16; - routingData.mDtypeScore = batchedGemm::trtllm::gen::Dtype::Fp32; routingData.mUsePdl = use_pdl; // Input tensors (pre-computed expert selections) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 64fece5021..f3957e9717 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -867,7 +867,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { RoutingMethodType::Renormalize || static_cast(routing_method_type) == RoutingMethodType::RenormalizeNaive) { - TVM_FFI_ICHECK(args->top_k <= 10 && args->top_k > 0) + TVM_FFI_ICHECK(args->top_k <= 32 && args->top_k > 0) << "Current routing kernel (no groups, renormalize) only supports top_k<=10 && top_k>0."; } else if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { TVM_FFI_ICHECK_EQ(args->top_k, 1) diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu deleted file mode 100644 index 
5408d2d059..0000000000 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ /dev/null @@ -1,695 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include "flashinfer/exception.h" -#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" - -namespace moe::dev::routing { - -namespace routingDeepSeek { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static constexpr int NumNemotronExperts = 512; -static constexpr int NumKimiK2Experts = 384; -static constexpr int NumDeepseekExperts = 256; -static constexpr int MaxSupportedExpertCount = - std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); -static constexpr int NumTopGroupScores = 2; -static constexpr int DefaultMaxNumTopExperts = 8; -static constexpr int MaxSupportedTopExperts = 22; -static constexpr int MaxNumTopGroups = 4; -static constexpr int MaxNumGroups = 8; - -template -__global__ void routingMainKernel(KernelParams params) { - // declare types - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - - // declare shared memory structure - // number of experts is bounded by number of threads - __shared__ float __attribute((aligned(128))) smemScoreSigmoid[KernelParams::MaxNumExperts]; - __shared__ float __attribute((aligned(128))) smemScoreBias[KernelParams::MaxNumExperts]; - // 
number of expert groups is bounded by number of warps - __shared__ float __attribute((aligned(128))) smemGroupScores[MaxNumGroups]; - - // needed for warp reduce - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - // for the final reduction of weight norm, only some lanes need to participate - int32_t laneIdx = threadIdx.x % WarpSize; - int32_t warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - // warps outside the range of expert groups do not participate - if constexpr (KernelParams::UseGroups) { - if (warpIdx >= params.mNumExpertGroups) { - return; - } - } - - // note that for invalid scores, we use negative infinity, - // needed for GLM-style routing where bias can be negative - static constexpr float invalidScoreFloat = -float(INFINITY); - const OutputT invalidScore = OutputT{invalidScoreFloat}; - - // load bias already; each warp represents one expert group - auto threadExpert = threadIdx.x; - bool expertSelected = threadExpert < params.mNumExperts; - if constexpr (KernelParams::UseGroups) { - threadExpert = warpIdx * params.mNumExpertsPerGroup + laneIdx; - expertSelected = laneIdx < params.mNumExpertsPerGroup; - } - auto scoreIdx = int64_t{blockIdx.x} * int64_t{params.mNumExperts} + threadExpert; - auto biasVal = - expertSelected ? 
static_cast(params.mPtrRoutingBias[threadExpert]) : invalidScoreFloat; - // initialize the mPtrExpertCounts - if (params.mPtrExpertCounts) { - int32_t globalThreadIdx = blockIdx.x * blockDim.x + threadIdx.x; - int32_t globalThreadStride = gridDim.x * blockDim.x; - int32_t expertCountsNum = 2 * params.mNumExperts; - initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); - } - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - // trigger the secondary kernel when using PDL, then wait on primary - if constexpr (KernelParams::UsePdl) { - cudaTriggerProgrammaticLaunchCompletion(); - cudaGridDependencySynchronize(); - } -#endif - - if (params.mPtrScores != nullptr) { - // get our assigned thread score; each warp represents one expert group - float score = - expertSelected ? static_cast(params.mPtrScores[scoreIdx]) : invalidScoreFloat; - // get the sigmoid score - // note that for invalid values, we simply use a negative value: - // sigmoig scores are always strictly positive - auto scoreSigmoid = sigmoid_accurate(score); - // write the sigmoid score to shared for later use - if (expertSelected) { - smemScoreSigmoid[threadExpert] = scoreSigmoid; - } - // get the score with bias - // note: with invalid values, invalidScoreFloat ensures values are always smaller than valid - // ones - auto scoreBias = float{scoreSigmoid + float{biasVal}}; - - if (expertSelected) { - smemScoreBias[threadExpert] = scoreBias; - } - - // registers for top group score reduction - float topExpGroupScores[NumTopGroupScores]; - [[maybe_unused]] int32_t topExpGroupIdx[NumTopGroupScores]; - float topGroups[MaxNumTopGroups]; // bound of params.mNumLimitedGroups - int32_t topGroupIdx[MaxNumTopGroups]; - float expertScoreGroup[MaxNumTopGroups]; - int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK - int32_t topExperts[KernelParams::MaxNumTopExperts]; - - if constexpr (KernelParams::UseGroups) { - 
topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, - /* minValue */ invalidScoreFloat); - // get the final group score and write it to shared - if (cute::elect_one_sync()) { - auto groupScore = topExpGroupScores[0] + topExpGroupScores[1]; - smemGroupScores[warpIdx] = groupScore; - } - } - - // make group scores available to all warps - __syncthreads(); - - auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - if constexpr (KernelParams::UseGroups) { // a single warp performs the selection of top groups, - // and goes on to select the final experts - if (warpIdx == 0) { - float groupScore = - laneIdx < params.mNumExpertGroups ? smemGroupScores[laneIdx] : invalidScoreFloat; - topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx, - /* minValue */ invalidScoreFloat); - // final expert selection: get relevant indexes and scores from shared - -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { // bound of params.mNumLimitedGroups - auto groupIdx = topGroupIdx[ii]; - expertIdxGroup[ii] = groupIdx * params.mNumExpertsPerGroup + laneIdx; - // note: expertSelected implies laneIdx < params.mNumExpertsPerGroup. - // we have params.mNumExpertsPerGroup == params.mNumExperts / params.mNumExpertGroups, - // thus groupIdx <= params.mNumExpertGroups - 1 => - // groupIdx * params.mNumExpertsPerGroup <= params.mNumExperts - - // params.mNumExpertsPerGroup - // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, - // so the access is safe here - expertScoreGroup[ii] = (ii < params.mNumLimitedGroups) && - (groupIdx < params.mNumExpertGroups) && expertSelected - ? 
smemScoreBias[expertIdxGroup[ii]] - : invalidScoreFloat; - } - - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } else if constexpr (KernelParams::MaxNumExperts > topk::MaxNumExpertsUnit) { - // without groups, each thread just takes `MaxNumTopGroups` experts - int constexpr NumExpertWarps = - (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; - __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; - __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; - if (warpIdx < NumExpertWarps) { - int offset = warpIdx * WarpSize * MaxNumTopGroups; -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - auto expertIdx = ii * WarpSize + laneIdx; - expertIdxGroup[ii] = offset + expertIdx; - expertScoreGroup[ii] = offset + expertIdx < params.mNumExperts - ? 
smemScoreBias[offset + expertIdx] - : invalidScoreFloat; - } - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - - if (laneIdx < params.mTopK) { - smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = - topScores[laneIdx]; - smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = - topExperts[laneIdx]; - } else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) { - smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = - invalidScoreFloat; - smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = - MaxSupportedExpertCount - 1; - } - } - __syncthreads(); - if (warpIdx == 0) { - int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; - float intermidiateScore[NumInterTopKPerThread]; - int32_t intermidiateExpert[NumInterTopKPerThread]; - for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { - int ii = i / WarpSize; - if (i < NumInterTopK) { - intermidiateScore[ii] = smemInterTopScores[i]; - intermidiateExpert[ii] = smemInterTopExperts[i]; - } else { - intermidiateScore[ii] = invalidScoreFloat; - intermidiateExpert[ii] = KernelParams::MaxNumExperts - 1; - } - } - topk::reduceTopK(warp, topScores, topExperts, intermidiateScore, intermidiateExpert, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } else { - if (warpIdx == 0) { - // without groups, each thread just takes `MaxNumTopGroups` experts -#pragma unroll - for (int ii = 0; ii < MaxNumTopGroups; ++ii) { - auto expertIdx = ii * WarpSize + laneIdx; - expertIdxGroup[ii] = expertIdx; - expertScoreGroup[ii] = - expertIdx < params.mNumExperts ? 
smemScoreBias[expertIdx] : invalidScoreFloat; - } - topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, - /* minValue */ invalidScoreFloat, params.mTopK); - } - } - - if (warpIdx == 0) { - // determine our lane's expert index and write to output - int32_t expertIdx = 0; -#pragma unroll - for (int ii = 0; ii < params.mTopK; ++ii) { // bound of params.mTopK - expertIdx = laneIdx == ii ? topExperts[ii] : expertIdx; - } - // determine whether our expert is local to this GPU - auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - - float scoreNorm = laneIdx < params.mTopK ? smemScoreSigmoid[expertIdx] : 0.F; - auto redNorm = cg::reduce(warp, scoreNorm, cg::plus{}); - auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; - - // write expert idx out already - auto idxTopK = blockIdx.x * params.mTopK + laneIdx; - if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) { - PackedScoreIdx packedScore{static_cast(finalScore), - static_cast(expertIdx)}; - params.mPtrTopKPacked[idxTopK] = packedScore; - } - - if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && - params.mPtrTopKIds == nullptr) { - params.mPtrTopKWeights[idxTopK] = finalScore; - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) - __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesClusterKernel(KernelParams params) { - using OutputT = typename KernelParams::OutputT; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const clusterBlockRank = blockIdx.x; - - //@todo: try to move it into routingPermutation - // then wait on primary grid - if constexpr 
(KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } - routingPermutation(params, nullptr, warpIdx, clusterBlockRank); -} -#else -__global__ void routingIndicesClusterKernel(KernelParams params) { - assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesCoopKernel(KernelParams params) { - // number of experts is bounded by number of threads - int constexpr NumThreads = KernelParams::MaxNumExperts; - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; - // needed for the exclusive sum of token offsets - using Scan = cub::BlockScan; - __shared__ typename Scan::TempStorage tempStorage; - // 64 elements -> 128+ registers. Above that we may start to see spilling to local memory. - static constexpr int MaxExpandedIdxPerThread = 64; - - // Initialize grid. - cg::grid_group grid = cg::this_grid(); - // Note: the following is more efficient than grid.block_index() because we don't use y and z. 
- int32_t const gridBlockIdx = blockIdx.x; - int32_t const gridThreadIdx = NumThreads * gridBlockIdx + threadIdx.x; - int32_t const numBlocks = gridDim.x; - int32_t const numThreadsPerGrid = numBlocks * NumThreads; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - - auto expandedIdxSize = params.mNumTokens * params.mTopK; - - // pre-fill the counts with 0 - smemExpertCount[threadIdx.x] = 0; - __syncthreads(); - - // then wait on primary grid - if constexpr (KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } - - // each thread keeps has some number of "expanded indexes" assigned to it - // for each of these, we keep the associated expert and offset within expert in registers - int32_t expertIndexes[MaxExpandedIdxPerThread]; - int32_t expertOffsets[MaxExpandedIdxPerThread]; - auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a - // time, and branch between a fast path without bound checks and a slow path with bound checks. - int constexpr IterStride = 4; - static_assert(MaxExpandedIdxPerThread % IterStride == 0); - - // Define a lambda to avoid code duplication in both branches. - auto loopBody = [&](int ii, int expandedIdx) { - int32_t expertIdx = params.mPtrTopKIds != nullptr ? params.mPtrTopKIds[expandedIdx] - : params.mPtrTopKPacked[expandedIdx].idx; - expertIndexes[ii] = expertIdx; - // check whether this expert is local to our GPU at all and ignore if not - auto localExpertIdx = expertIdx - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - expertOffsets[ii] = isLocalExpert ? 
atomicAdd(smemExpertCount + expertIdx, 1) : 0; - }; - -#pragma unroll - for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { - // Whether it's safe to do multiple iterations without bound checks. - bool const takeFastPath = (ii0 + IterStride) * numThreadsPerGrid <= expandedIdxSize; - if (takeFastPath) { -#pragma unroll - for (int32_t jj = 0; jj < IterStride; jj++) { - int const ii = ii0 + jj; - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - loopBody(ii, expandedIdx); - } - } else { - bool doBreak = false; -#pragma unroll - for (int32_t jj = 0; jj < IterStride; jj++) { - int const ii = ii0 + jj; - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - if (expandedIdx >= expandedIdxSize) { - doBreak = true; - break; - } - loopBody(ii, expandedIdx); - } - if (doBreak) { - break; - } - } - } - - // Make histogram (token counts per expert) available to all threads in the block. - __syncthreads(); - - // - // Each thread now represents one expert - // - - // Add the local bin count to the common bin count and get a per-CTA offset. - int32_t const localExpertCount = smemExpertCount[threadIdx.x]; - - int32_t blockExpertOffset = 0; - if (threadIdx.x < params.mNumExperts) { - blockExpertOffset = atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); - } - - // Sync to wait for completion of the histogram reduction. - grid.sync(); - - // Get total count for this expert. - int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; - - // Note: the scan is redundant in all CTAs, but doing it in only 1 CTA would be worse for latency. - - // Compute the runtime config for projections - // Whether or not an expert is local is taken into account when smemExpertCount is computed - // so we do not need to take it into account here. 
- - int32_t numCta; - if constexpr (KernelParams::isPow2) { - numCta = divUpLog2(count, params.mPaddingLog2); - } else { - numCta = divUpTileN(count, params.mTileTokensDim); - } - - int32_t ctaOffset; - int32_t numNonExitingCtas; - Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - - for (int32_t cta = gridBlockIdx; cta < numCta; cta += numBlocks) { - const int32_t localExpertIdx = - (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; - } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; - } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); - } - - // get the padded offset associated with this expert - int32_t offset; - if constexpr (KernelParams::isPow2) { - offset = mulLog2(ctaOffset, params.mPaddingLog2); - } else { - offset = mulTileN(ctaOffset, params.mTileTokensDim); - } - int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); - } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); - } - - // write out padded count - if (gridBlockIdx == 0 && warpIdx == NumThreads / WarpSize - 1 && cute::elect_one_sync()) { - params.mPtrPermutedIdxSize[0] = permutedIdxSize; - params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; - } - - // write expert offsets to shared - smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; - - // make expert offsets available to all threads - __syncthreads(); - - // trigger the secondary kernel when using PDL - // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, - // 
mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens - // TODO: this is not sufficient to ensure visibility in the next kernel! - if constexpr (KernelParams::UsePdl) { - cudaTriggerProgrammaticLaunchCompletion(); - } - -// each thread has the same "expanded indexes" assigned to it as above -// at this point, we know the final offsets of experts and the offsets within -// experts, which allows writing the final index values -#pragma unroll - for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) { - auto expandedIdx = static_cast(gridThreadIdx) + ii * numThreadsPerGrid; - if (expandedIdx >= expandedIdxSize) { - break; - } - auto expertIdx = expertIndexes[ii]; - // check whether this expert is local to our GPU at all - auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - auto tokenIdx = expandedIdx / params.mTopK; - auto permutedIdx = - isLocalExpert ? 
int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; - if (params.mPtrExpandedIdxToPermutedIdx != nullptr) { - params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; - } - if (params.mPtrPermutedIdxToExpandedIdx != nullptr && isLocalExpert) { - params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; - } - if (params.mPtrPermutedIdxToTokenIdx != nullptr && isLocalExpert) { - params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; - } - } -} -#else -__global__ void routingIndicesCoopKernel(KernelParams params) { - assert(false && "routingIndicesCoopKernel is only supported on SM90+ architectures"); -} -#endif - -int constexpr getMaxNumExperts(int32_t numExperts) { - if (numExperts <= topk::MaxNumExpertsUnit) { - return topk::MaxNumExpertsUnit; - } else if (numExperts <= NumDeepseekExperts) { - return NumDeepseekExperts; - } else if (numExperts <= NumKimiK2Experts) { - return NumKimiK2Experts; - } else if (numExperts <= NumNemotronExperts) { - return NumNemotronExperts; - } else { - TLLM_LOG_ERROR("Unsupported numExperts"); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#define LAUNCH_ROUTING_DEEPSEEK(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, \ - extraFlag) \ - if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, topk::MaxNumExpertsUnit, \ - DefaultMaxNumTopExperts); \ - } else if (data.mNumExperts <= NumDeepseekExperts) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumDeepseekExperts, DefaultMaxNumTopExperts); \ - } else if (data.mNumExperts <= NumKimiK2Experts) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumKimiK2Experts, DefaultMaxNumTopExperts); \ - } else if 
(data.mNumExperts <= NumNemotronExperts) { \ - if (data.mTopK <= DefaultMaxNumTopExperts) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumNemotronExperts, \ - DefaultMaxNumTopExperts); \ - } else { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumNemotronExperts, MaxSupportedTopExperts); \ - } \ - } else { \ - TLLM_LOG_ERROR("Unsupported numExperts"); \ - } - -void runImpl(Data& data, void* stream) { - FLASHINFER_CHECK( - data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr, - "Routing kernel requires at least one input parameter"); - if (data.mPtrTopKIds != nullptr) { - FLASHINFER_CHECK(data.mPtrTopKWeights != nullptr, - "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " - "DeepSeek routing."); - } - if (data.mPtrExpandedIdxToPermutedIdx != nullptr || - data.mPtrPermutedIdxToExpandedIdx != nullptr || data.mPtrPermutedIdxToTokenIdx != nullptr) - FLASHINFER_CHECK( - (data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) && data.mPtrPermutedIdxSize, - "If permuted index is required, `mPtrTopKPacked` or `mPtrTopKIds` is also required"); - FLASHINFER_CHECK(!data.mUseRoutingSoftmax, "Routing with softmax not implemented yet"); - FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, - "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, - data.mNumLimitedGroups); - // Test limits according to values passed in launch, see definition of LAUNCH_ROUTING_DEEPSEEK - if (data.mNumExperts <= NumKimiK2Experts) { - FLASHINFER_CHECK( - data.mTopK <= DefaultMaxNumTopExperts, - "When NumExperts <= NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d", - DefaultMaxNumTopExperts, data.mTopK); - } else { - FLASHINFER_CHECK( - data.mTopK <= MaxSupportedTopExperts, - "When NumExperts > NumKimiK2Experts, routing kernel expects topK experts <= 
%d, got %d", - MaxSupportedTopExperts, data.mTopK); - } - FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", - data.mTopK); - FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, - "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", - data.mTopK, data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mTopK <= data.mNumExperts, - "Routing kernel expects topK %d to be at most #experts %d", data.mTopK, - data.mNumExperts); - FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExpertCount, - "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, - MaxSupportedExpertCount); - FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, - "Routing kernel expects top groups %d to be limited by #expert groups %d", - data.mNumLimitedGroups, data.mNumExpertGroups); - if (data.mNumExpertGroups > 1) { - FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups, - "Routing kernel expects #experts groups %d to be <= #warps %d", - data.mNumExpertGroups, MaxNumGroups); - FLASHINFER_CHECK(data.mNumExperts % data.mNumExpertGroups == 0, - "Routing kernel expects #experts %d to be a multiple of #expert groups %d", - data.mNumExperts, data.mNumExpertGroups); - FLASHINFER_CHECK( - data.mNumExperts / data.mNumExpertGroups <= WarpSize, - "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", - data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); - } - FLASHINFER_CHECK(data.mNumExperts % 4 == 0, - "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); - - int const numBlocks = data.mNumTokens; - int const numThreadsHist = getMaxNumExperts(data.mNumExperts); - - bool const useSingleCluster = data.mNumTokens <= 1024; - if (!useSingleCluster) { - // Reset the global histograms (not used in single-cluster code path). - // Cover both for the cooperative and two-kernel code paths. 
- FLASHINFER_CHECK(data.mPtrExpertCounts != nullptr, - "When #tokens is large, `mPtrExpertCounts` is a required input."); - } else { - data.mPtrExpertCounts = - nullptr; // Set it to nullptr for single-cluster code path, as it won't be used - } - - // Number of blocks we can use in the cooperative kernel - // The number of blocks must be: - // >= ⌈(numTokens * topK) / (MaxExpandedIdxPerThread * NumThreads)⌉ - // <= numSms, assuming an occupancy of 1 block/SM - // - // If too small for the given numTokens, fall back to the less performant two-step method. - // - // The upper bound is a strict requirement. The number of blocks should be determined by querying - // the device properties, or conservatively low. - // /!\ The following number is not portable!! (but works on H100 and B200) - int const numBlocksCoop = 128; - - // Maximum number of tokens supported by the kernel using a cooperative launch. - int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; - if (data.mPtrTopKIds == nullptr) { - int const numThreadsMain = - max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - } else { - // Reset the global histograms. 
- LAUNCH_ROUTING_DEEPSEEK(data, false, routingInitExpertCounts, - (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - } - - if (data.mPtrPermutedIdxSize != nullptr) { - if (useSingleCluster) { - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesClusterKernel, - NumBlocksPerCluster, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - } else if (data.mNumTokens <= maxTokensCoop) { - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/true, routingIndicesCoopKernel, numBlocksCoop, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - } else { - const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; - const int32_t histogramEltsPerBlock = 8 * numThreadsHist; - const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; - - // Limit grid size (both kernels use a grid-stride loop). 
- const int32_t maxNumBlocks = 1024; - - int const numBlocksHistogram = std::min( - (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); - int const numBlocksOffsets = - std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesHistogramKernel, - numBlocksHistogram, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - LAUNCH_ROUTING_DEEPSEEK(data, - /*coopLaunch=*/false, routingIndicesOffsetsKernel, numBlocksOffsets, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mNumExpertGroups > 1); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void run(Data& data, void* stream) { runImpl(data, stream); } - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingDeepSeek -} // namespace moe::dev::routing diff --git a/csrc/trtllm_fused_moe_routing_renormalize.cu b/csrc/trtllm_fused_moe_routing_renormalize.cu deleted file mode 100644 index 364c267c00..0000000000 --- a/csrc/trtllm_fused_moe_routing_renormalize.cu +++ /dev/null @@ -1,506 +0,0 @@ -/* - * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" -#include "tvm_ffi_utils.h" - -namespace moe::dev::routing { -namespace routingRenormalize { -//////////////////////////////////////////////////////////////////////////////////////////////////// - -static constexpr int NumThreads = 1024; -static constexpr int NumWarps = NumThreads / WarpSize; -static constexpr int MaxNumTopExperts = 10; -static constexpr int NumExpertsLimit = 512; -static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; -static constexpr int MaxNumTokensSingleClusterScores = NumBlocksPerCluster * NumWarps; -static constexpr int BlockKernelMaxNumTokens = 4; - -template -__forceinline__ __device__ void routingTopKExperts( - cg::thread_block_tile const& warp, DataType (&score)[VecSize], - int32_t (&idx)[VecSize], DataType (&warpTopKScore)[MaxNumTopExperts], - int32_t (&warpTopKExpertIdx)[MaxNumTopExperts], int32_t const laneIdx, int32_t const numExperts, - int32_t topK, InputType const* ptrScores, bool const normTopkProb, - bool const applySoftmaxAfterTopK) { - DataType minScore = DataType{-INFINITY}; - - for (int i = 0; i < VecSize; i++) { - auto expertIdx = i * WarpSize + laneIdx; - auto newScore = expertIdx < numExperts ? static_cast(ptrScores[expertIdx]) : minScore; - score[i] = newScore; - idx[i] = expertIdx; - } - if constexpr (DoSoftmaxBeforeTopK) { - calcSoftmax(warp, score); - } - - // Get the top-k scores and their corresponding expert indices - topk::reduceTopK(warp, warpTopKScore, warpTopKExpertIdx, score, idx, minScore, topK); - - // Normalize the scores - if constexpr (DoSoftmaxBeforeTopK) { - float sum = float{1.f}; - if (normTopkProb) { - sum = static_cast(laneIdx < topK ? warpTopKScore[laneIdx] : 0); - sum = cg::reduce(warp, sum, cg::plus()); - } - if (laneIdx < topK) { - warpTopKScore[laneIdx] = warpTopKScore[laneIdx] / sum; - } - } else { - if (applySoftmaxAfterTopK) { - auto softmaxScore = - calcSoftmax(warp, laneIdx < topK ? 
warpTopKScore[laneIdx] : minScore, laneIdx, topK); - if (laneIdx < topK) { - warpTopKScore[laneIdx] = softmaxScore; - } - } - // If applySoftmaxAfterTopK is false, we keep the raw TopK values without softmax - } -} - -template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesBlockKernel(KernelParams params) { - // types used in this kernel - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; - using TypePacked = PackedScoreIdx; - int constexpr MaxNumExperts = KernelParams::MaxNumExperts; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const laneIdx = cutlass::arch::LaneId(); - int32_t const expert = threadIdx.x; - auto scoreOffset = warpIdx * params.mNumExperts; - bool validToken = warpIdx < params.mNumTokens; - - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - static constexpr int totalExpertCounts = BlockKernelMaxNumTokens * MaxNumExperts; - __shared__ int8_t __attribute((aligned(128))) smemOffset[totalExpertCounts]; - __shared__ int8_t __attribute((aligned(128))) smemKIdx[totalExpertCounts]; - - using Scan = cub::BlockScan; - __shared__ typename Scan::TempStorage tempStorage; - - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - for (int i = threadIdx.x; i < totalExpertCounts; i += blockDim.x) { - smemOffset[i] = int8_t{-1}; - smemKIdx[i] = int8_t{-1}; - } - __syncthreads(); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // then wait on primary grid - if constexpr (KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } -#endif - - if (params.mPtrTopKIds != nullptr) { - if (validToken) { - if (laneIdx < params.mTopK) { - int offset = warpIdx * MaxNumExperts + params.mPtrTopKIds[warpIdx * params.mTopK + laneIdx]; - smemKIdx[offset] = static_cast(laneIdx); - } - } - } else if (params.mPtrScores != nullptr) { - // in this case, each 
warp represents a token - BaseType score[VecSize]; - int32_t idx[VecSize]; - - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t warpTopKExpertIdx[MaxNumTopExperts]; - - BaseType minScore = BaseType{-INFINITY}; - if (validToken) { - routingTopKExperts( - warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, - params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, - params.mApplySoftmaxAfterTopK); - - if (laneIdx < params.mTopK) { - int offset = warpIdx * MaxNumExperts + warpTopKExpertIdx[laneIdx]; - smemKIdx[offset] = static_cast(laneIdx); - if (params.mPtrTopKWeights != nullptr) { - params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = - OutputT{warpTopKScore[laneIdx]}; - } - } - } // end if (validToken) - } else if (params.mPtrTopKPacked != nullptr) { - if (validToken) { - if (laneIdx < params.mTopK) { - int offset = warpIdx * MaxNumExperts + - static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].idx); - smemKIdx[offset] = static_cast(laneIdx); - if (params.mPtrTopKWeights != nullptr) { - params.mPtrTopKWeights[warpIdx * params.mTopK + laneIdx] = - static_cast(params.mPtrTopKPacked[warpIdx * params.mTopK + laneIdx].score); - } - } - } - } - __syncthreads(); - - // set local experts - auto localExpertIdx = expert - params.mLocalExpertsStartIdx; - auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < params.mNumLocalExperts && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; - // Get the count of each expert and the offset for each token - int accExpertCount = 0; - - if (isLocalExpert) { - int offset = expert; - for (int j = 0; j < BlockKernelMaxNumTokens; j++) { - if (smemKIdx[offset] >= 0) { - smemOffset[offset] = static_cast(accExpertCount); - accExpertCount++; - } - offset += MaxNumExperts; - } - } - __syncthreads(); - // Get the number of CTAs and the offset for each CTA - int32_t numCta; - if constexpr (KernelParams::isPow2) { - numCta = divUpLog2(accExpertCount, 
params.mPaddingLog2); - } else { - numCta = divUpTileN(accExpertCount, params.mTileTokensDim); - } - int32_t ctaOffset = 0; - int32_t numNonExitingCtas; - Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - - int32_t expertScanCounts = 0; - int32_t tmpCount; - if constexpr (KernelParams::isPow2) { - tmpCount = divUpMulLog2(accExpertCount, params.mPaddingLog2); - } else { - tmpCount = divUpMulTileN(accExpertCount, params.mTileTokensDim); - } - Scan(tempStorage).ExclusiveSum(tmpCount, expertScanCounts); - __syncthreads(); - - if (isLocalExpert) { - for (int cta = 0; cta < numCta; ++cta) { - const int32_t localExpertIdx = - (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + accExpertCount; - } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + accExpertCount; - } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); - } - } - - // at this point, we can write out padded count - if (threadIdx.x == 0) { - int32_t permutedIdxSize; - if constexpr (KernelParams::isPow2) { - permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); - } else { - permutedIdxSize = mulTileN(numNonExitingCtas, params.mTileTokensDim); - } - params.mPtrPermutedIdxSize[0] = permutedIdxSize; - params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; - } - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // we can trigger the next kernel at this point - if constexpr (KernelParams::UsePdl) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - - for (int tokenIdx = 0; tokenIdx < params.mNumTokens; tokenIdx++) { - int offset = tokenIdx * MaxNumExperts + threadIdx.x; - if 
(smemKIdx[offset] >= 0) { - int const expandedIdx = tokenIdx * params.mTopK + smemKIdx[offset]; - int const offsetWithinExpert = static_cast(smemOffset[offset]); - int const offsetForExpert = expertScanCounts; - int const permutedIdx = isLocalExpert ? offsetForExpert + offsetWithinExpert : int32_t{-1}; - - params.mPtrExpandedIdxToPermutedIdx[expandedIdx] = permutedIdx; - if (isLocalExpert) { - if (params.mPtrPermutedIdxToExpandedIdx != nullptr) { - params.mPtrPermutedIdxToExpandedIdx[permutedIdx] = expandedIdx; - } - params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; - } - } - } -} - -template -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -__global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(NumThreads) - routingIndicesClusterKernel(KernelParams params) { - // number of tokens/expanded idx is bounded by total number of warps - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - - using BaseType = std::conditional_t; - using TypePacked = PackedScoreIdx; - - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - - __shared__ TypePacked __attribute((aligned(128))) smemPackedScoreIdx[NumWarps * MaxNumTopExperts]; - - uint32_t const clusterBlockRank = blockIdx.x; - - int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); - int32_t const laneIdx = cutlass::arch::LaneId(); - - auto warpTokenIdx = clusterBlockRank * NumWarps + warpIdx; - auto scoreOffset = warpTokenIdx * params.mNumExperts; - bool validToken = warpTokenIdx < params.mNumTokens; - - auto block = cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - - // then wait on primary grid - if constexpr (KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } - - if (params.mPtrScores != nullptr) { - // in this case, each warp represents a token - BaseType score[VecSize]; - int32_t idx[VecSize]; - - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t 
warpTopKExpertIdx[MaxNumTopExperts]; - - BaseType minScore = BaseType{-INFINITY}; - if (validToken) { - routingTopKExperts( - warp, score, idx, warpTopKScore, warpTopKExpertIdx, laneIdx, params.mNumExperts, - params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, - params.mApplySoftmaxAfterTopK); - - if (laneIdx < params.mTopK) { - smemPackedScoreIdx[warpIdx * params.mTopK + laneIdx] = - TypePacked{warpTopKScore[laneIdx], static_cast(warpTopKExpertIdx[laneIdx])}; - } - } // end if (validToken) - } - - // make packed scores available to all threads in cluster - __cluster_barrier_arrive(); - __cluster_barrier_wait(); - - if (params.mPtrScores != nullptr) { - routingPermutation(params, smemPackedScoreIdx, warpIdx, - clusterBlockRank); - } else { - routingPermutation(params, smemPackedScoreIdx, warpIdx, - clusterBlockRank); - } -} -#else -__global__ void __launch_bounds__(NumThreads) - routingIndicesClusterKernel(KernelParams /* params */) { - assert(false && "routingIndicesClusterKernel is only supported on SM90+ architectures"); -} -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// this kernel is needed in case we have scores as input for the histogram kernel -template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) - routingIndicesHistogramScoresKernel(KernelParams params) { - using OutputT = typename KernelParams::OutputT; - using InputT = typename KernelParams::InputT; - using BaseType = std::conditional_t; - - static constexpr int VecSize = KernelParams::MaxNumExperts / WarpSize; - - int32_t const laneIdx = cutlass::arch::LaneId(); - int32_t const warpIdx = threadIdx.x / WarpSize; - int32_t const globalWarpIdx = blockIdx.x * KernelParams::MaxNumExperts / WarpSize + warpIdx; - int32_t const globalWarpStride = gridDim.x * KernelParams::MaxNumExperts / WarpSize; - BaseType minScore = BaseType{-INFINITY}; - auto block 
= cg::this_thread_block(); - auto warp = cg::tiled_partition(block); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid. - if constexpr (KernelParams::UsePdl) { - cudaGridDependencySynchronize(); - } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - - // initialize the mPtrExpertCounts - int32_t expertCountsNum = 2 * params.mNumExperts; - int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; - int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; - initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); - -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Trigger secondary kernel. - if constexpr (KernelParams::UsePdl) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - - // in this case, each warp represents a token, and we use a grid-stride loop - // over all warps/tokens - BaseType allScores[VecSize]; - int32_t allExpertIdx[VecSize]; - BaseType warpTopKScore[MaxNumTopExperts]; - int32_t warpTopKExpertIdx[MaxNumTopExperts]; - for (int tokenIdx = globalWarpIdx; tokenIdx < params.mNumTokens; tokenIdx += globalWarpStride) { - auto scoreOffset = tokenIdx * params.mNumExperts; - - routingTopKExperts( - warp, allScores, allExpertIdx, warpTopKScore, warpTopKExpertIdx, laneIdx, - params.mNumExperts, params.mTopK, params.mPtrScores + scoreOffset, params.mNormTopkProb, - params.mApplySoftmaxAfterTopK); - - if (laneIdx < params.mTopK) { - PackedScoreIdx packedScore{static_cast(warpTopKScore[laneIdx]), - static_cast(warpTopKExpertIdx[laneIdx])}; - params.mPtrTopKPacked[tokenIdx * params.mTopK + laneIdx] = packedScore; - } - } -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -int32_t constexpr getMaxNumExperts(int32_t numExperts) { - if (numExperts <= topk::MaxNumExpertsUnit) { - return topk::MaxNumExpertsUnit; - } 
else if (numExperts <= NumExpertsLimit) { - return NumExpertsLimit; - } else { - TLLM_LOG_ERROR("Unsupported numExperts"); - return 0; - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define LAUNCH_ROUTING_RENORNALIZE(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1) \ - if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, topk::MaxNumExpertsUnit); \ - } else if (data.mNumExperts <= NumExpertsLimit) { \ - LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, NumExpertsLimit); \ - } else { \ - TLLM_LOG_ERROR("Unsupported numExperts"); \ - } - -//////////////////////////////////////////////////////////////////////////////////////////////////// -void run(Data const& data, void* stream) { - TVM_FFI_ICHECK(data.mPtrTopKPacked != nullptr || data.mPtrScores != nullptr || - data.mPtrTopKIds != nullptr) - << "Routing kernel requires at least one input parameter"; - if (data.mPtrTopKIds != nullptr) { - TVM_FFI_ICHECK(data.mPtrTopKWeights != nullptr) - << "When mPtrTopKIds is provided, mPtrTopKWeights must also be provided for " - "Renormalize routing."; - } - TVM_FFI_ICHECK(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr && - data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr) - << "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers"; - TVM_FFI_ICHECK_LE(data.mTopK, MaxNumTopExperts) - << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK; - TVM_FFI_ICHECK_LE(data.mNumExperts, NumExpertsLimit) - << "Routing kernel expects #experts " << data.mNumExperts << " to be no more than " - << NumExpertsLimit << "."; - TVM_FFI_ICHECK_EQ(data.mNumExperts % 4, 0) - << "Routing kernel expects 
#experts " << data.mNumExperts << " to be a multiple of 4."; - - // FIXME: routingIndicesBlockKernel breaks the vllm + gpt-oss DeepEP - bool const useSingleBlock = - data.mNumTokens <= BlockKernelMaxNumTokens && data.mPtrTopKPacked == nullptr; - - bool const useSingleCluster = - data.mNumTokens <= ((data.mPtrScores != nullptr || data.mPtrTopKIds != nullptr) - ? MaxNumTokensSingleClusterScores - : MaxNumTokensSingleCluster); - - if (!useSingleCluster && !useSingleBlock) { - TVM_FFI_ICHECK(data.mPtrTopKPacked != nullptr || data.mPtrTopKIds != nullptr) - << "When #tokens is large, `mPtrTopKPacked` or `mPtrTopKIds` is a required input."; - TVM_FFI_ICHECK(data.mPtrExpertCounts != nullptr) - << "When #tokens is large, `mPtrExpertCounts` is a required input."; - } - uint32_t const numThreadsHist = getMaxNumExperts(data.mNumExperts); - if (useSingleBlock) { - //@TODO: For now we use the single block kernel for cases with token number no larger than 4. - // We will future tune this threshold based on the performance. - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesBlockKernel, 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } else if (useSingleCluster) { - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesClusterKernel, NumBlocksPerCluster, - NumThreads, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } else { - uint32_t const expandedIdxSize = data.mNumTokens * data.mTopK; - uint32_t const histogramEltsPerBlock = 8 * numThreadsHist; - uint32_t const offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; - - // Limit grid size (all kernels use a grid-stride loop). 
- uint32_t const maxNumBlocks = 1024; - - int const numBlocksHistogram = std::min( - (expandedIdxSize + histogramEltsPerBlock - 1) / histogramEltsPerBlock, maxNumBlocks); - int const numBlocksOffsets = - std::min((expandedIdxSize + offsetEltsPerBlock - 1) / offsetEltsPerBlock, maxNumBlocks); - - if (data.mPtrScores != nullptr && data.mPtrTopKIds == nullptr) { - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesHistogramScoresKernel, maxNumBlocks, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } else { - // Reset the global histograms. - LAUNCH_ROUTING_RENORNALIZE(data, false, routingInitExpertCounts, - (2 * data.mNumExperts - 1) / numThreadsHist + 1, numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesHistogramKernel, numBlocksHistogram, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - LAUNCH_ROUTING_RENORNALIZE(data, false, routingIndicesOffsetsKernel, numBlocksOffsets, - numThreadsHist, - /*smemSize=*/0, // No dynamic smem - stream, data.mDoSoftmaxBeforeTopK); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace routingRenormalize -} // namespace moe::dev::routing diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index af48040d0a..4091019efc 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -64,10 +64,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); moe::dev::routing::routingDeepSeek::Data routingData; routingData.mDtypeExpW = - btg::Dtype::Bfloat16; // for DeepSeek, the expW is currently always bfloat16 - routingData.mDtypeBias = dtypeBias; // for DeepSeek, the bias can be bfloat16 or fp32 - - 
routingData.mDtypeScore = btg::Dtype::Fp32; // for DeepSeek, the score is currently always fp32 + btg::Dtype::Bfloat16; // for DeepSeek, the expW is currently always bfloat16 routingData.mUsePdl = true; // output: diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py index f880ae6774..9611ac001f 100644 --- a/flashinfer/jit/cpp_ext.py +++ b/flashinfer/jit/cpp_ext.py @@ -290,7 +290,7 @@ def generate_ninja_build_for_op( is_cuda = source.suffix == ".cu" object_suffix = ".cuda.o" if is_cuda else ".o" cmd = "cuda_compile" if is_cuda else "compile" - obj_name = source.with_suffix(object_suffix).name + obj_name = f"{source.parent.name}_{source.stem}{object_suffix}" obj = str((output_dir / obj_name).resolve()) objects.append(obj) lines.append(f"build {obj}: {cmd} {source.resolve()}") diff --git a/flashinfer/jit/fused_moe.py b/flashinfer/jit/fused_moe.py index a3240f7353..d37be24e55 100644 --- a/flashinfer/jit/fused_moe.py +++ b/flashinfer/jit/fused_moe.py @@ -261,10 +261,40 @@ def gen_trtllm_gen_fused_moe_sm100_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu", jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_kernel_launcher.cu", jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_runner.cu", - jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_deepseek.cu", - jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_llama4.cu", - jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_renormalize.cu", - jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_dev_kernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/trtllm_fused_moe_routing_llama4.cu", + # DeepSeek routing (split files) + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchMainKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchClusterKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / 
"fused_moe/trtllm_backend/routingDeepSeek/launchCoopKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchHistogramKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchInitExpertCounts.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchOffsetsKernel.cu", + # Renormalize routing (split files) + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/trtllm_fused_moe_routing_renormalize.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingRenormalize/launchBlockKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingRenormalize/launchClusterKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingRenormalize/launchHistogramKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingRenormalize/launchHistogramScoresKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingRenormalize/launchInitExpertCounts.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingRenormalize/launchOffsetsKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/trtllm_fused_moe_dev_kernel.cu", jit_env.FLASHINFER_CSRC_DIR / "trtllm_batched_gemm_runner.cu", ], extra_cuda_cflags=[ diff --git a/flashinfer/jit/moe_utils.py b/flashinfer/jit/moe_utils.py index 9e0cc9e19c..bff6bb5e06 100644 --- a/flashinfer/jit/moe_utils.py +++ b/flashinfer/jit/moe_utils.py @@ -76,7 +76,20 @@ def gen_moe_utils_module() -> JitSpec: jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/tllmException.cpp", jit_env.FLASHINFER_CSRC_DIR / "nv_internal/cpp/common/memoryUtils.cu", # Routing kernels for moe_sort - jit_env.FLASHINFER_CSRC_DIR / "trtllm_fused_moe_routing_deepseek.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/trtllm_fused_moe_routing_deepseek.cu", + jit_env.FLASHINFER_CSRC_DIR + / 
"fused_moe/trtllm_backend/routingDeepSeek/launchMainKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchClusterKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchCoopKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchHistogramKernel.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchInitExpertCounts.cu", + jit_env.FLASHINFER_CSRC_DIR + / "fused_moe/trtllm_backend/routingDeepSeek/launchOffsetsKernel.cu", ], extra_cuda_cflags=nvcc_flags, extra_include_paths=[ diff --git a/include/flashinfer/trtllm/common/cudaUtils.h b/include/flashinfer/trtllm/common/cudaUtils.h index d10c40550a..3f7fa122e5 100644 --- a/include/flashinfer/trtllm/common/cudaUtils.h +++ b/include/flashinfer/trtllm/common/cudaUtils.h @@ -269,4 +269,11 @@ inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); } inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); } +inline int getMultiProcessorCount() { + int device, count; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, device); + return count; +} + } // namespace tensorrt_llm::common diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 560063c023..28c1603bc5 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -49,34 +49,36 @@ namespace moe::dev { #define LAUNCH_ESC(...) 
__VA_ARGS__ #define LAUNCH_PDL(data, coopLaunch, types, kernel, numBlocks, numThreads, smemSize, stream) \ - cudaLaunchConfig_t config{}; \ - config.gridDim = numBlocks; \ - config.blockDim = numThreads; \ - config.dynamicSmemBytes = smemSize; \ - config.stream = (cudaStream_t)stream; \ + do { \ + cudaLaunchConfig_t config{}; \ + config.gridDim = numBlocks; \ + config.blockDim = numThreads; \ + config.dynamicSmemBytes = smemSize; \ + config.stream = (cudaStream_t)stream; \ \ - cudaLaunchAttribute attributes[2] = {}; \ - attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ - attributes[0].val.programmaticStreamSerializationAllowed = int(data.mUsePdl); \ - attributes[1].id = cudaLaunchAttributeCooperative; \ - attributes[1].val.cooperative = int(coopLaunch); \ - config.attrs = attributes; \ - config.numAttrs = 2; \ - if (data.mUsePdl) { \ - auto params = KernelParams::setKernelParams(data); \ - auto kernelTyped = kernel>; \ - if (smemSize > 48 * 1024) \ - CHECK_CUDA_ERROR(cudaFuncSetAttribute( \ - kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \ - CHECK_CUDA_ERROR(cudaLaunchKernelEx(&config, kernelTyped, params)); \ - } else { \ - auto params = KernelParams::setKernelParams(data); \ - auto kernelTyped = kernel>; \ - if (smemSize > 48 * 1024) \ - CHECK_CUDA_ERROR(cudaFuncSetAttribute( \ - kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \ - CHECK_CUDA_ERROR(cudaLaunchKernelEx(&config, kernelTyped, params)); \ - } + cudaLaunchAttribute attributes[2] = {}; \ + attributes[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; \ + attributes[0].val.programmaticStreamSerializationAllowed = int(data.mUsePdl); \ + attributes[1].id = cudaLaunchAttributeCooperative; \ + attributes[1].val.cooperative = int(coopLaunch); \ + config.attrs = attributes; \ + config.numAttrs = 2; \ + if (data.mUsePdl) { \ + auto params = KernelParams::setKernelParams(data); \ + auto kernelTyped = kernel>; \ + if (smemSize > 48 * 
1024) \ + CHECK_CUDA_ERROR(cudaFuncSetAttribute( \ + kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \ + CHECK_CUDA_ERROR(cudaLaunchKernelEx(&config, kernelTyped, params)); \ + } else { \ + auto params = KernelParams::setKernelParams(data); \ + auto kernelTyped = kernel>; \ + if (smemSize > 48 * 1024) \ + CHECK_CUDA_ERROR(cudaFuncSetAttribute( \ + kernelTyped, cudaFuncAttributeMaxDynamicSharedMemorySize, smemSize)); \ + CHECK_CUDA_ERROR(cudaLaunchKernelEx(&config, kernelTyped, params)); \ + } \ + } while (0) #define LAUNCH(data, kernel, numBlocks, numThreads, smemSize, stream) \ if (data.mDtypeElt == tg::Dtype::Fp16) { \ @@ -157,94 +159,69 @@ namespace moe::dev { smemSize, stream); \ } -#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ - if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/), \ +#define LAUNCH_ROUTING_LLAMA4(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream) \ + if (data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN( \ + data, coopLaunch, \ + LAUNCH_ESC(float, float, 128 /* Always 128 for llama4*/, 1 /* Always 1 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/, \ + 1 /* Always 1 for llama4*/), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeExpW"); \ + } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag1, numExperts, numTopExperts) \ + if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, 
numTopExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \ kernel, numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, 128 /* Always 128 for llama4*/), kernel, \ - numBlocks, numThreads, smemSize, stream); \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } else { \ FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, extraFlag, numExperts, \ - numTopExperts) \ - if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, float, float, numExperts, numTopExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, 
stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, \ - numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ - } - -#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, numExperts, numTopExperts) \ - if (extraFlag) { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, true, 
numExperts, numTopExperts); \ - } else { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, false, numExperts, numTopExperts); \ - } - //////////////////////////////////////////////////////////////////////////////////////////////////// -#define LAUNCH_ROUTING_WITH_NUM_EXPERTS(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag1, numExperts) \ - if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag1) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, \ - numThreads, smemSize, stream); \ +#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \ + data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag, forceFloatInput, \ + numExperts, numTopExperts) \ + if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), \ + kernel, numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, \ - numThreads, smemSize, stream); \ - } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag1) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \ kernel, numBlocks, numThreads, 
smemSize, stream); \ + } else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \ + numBlocks, numThreads, smemSize, stream); \ } else if (data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \ kernel, numBlocks, numThreads, smemSize, stream); \ } else { \ FLASHINFER_WARN("Unsupported dtypeExpW"); \ diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh index 17143ab8a4..d1080a457f 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.cuh @@ -151,18 +151,21 @@ __device__ void calcSoftmax(cg::thread_block_tile const& warp, template __device__ DataType calcSoftmax(cg::thread_block_tile const& warp, DataType score, int32_t laneIdx, int32_t NumTopExperts) { - DataType maxScore = DataType{-INFINITY}; + // Compute max in float to support half/bfloat16 inputs safely. + // cg::reduce with cg::greater only supports float/double and integer types; + // using __nv_bfloat16 or __half directly can generate unsupported redux.sync.max instructions. + float maxScore = -INFINITY; if (laneIdx < NumTopExperts) { - maxScore = score >= maxScore ? score : maxScore; + float si = static_cast(score); + maxScore = si >= maxScore ? 
si : maxScore; } - maxScore = cg::reduce(warp, maxScore, cg::greater()); + maxScore = cg::reduce(warp, maxScore, cg::greater()); float sumScore = float{0.f}; - float newScore; - // Get the summation of scores for each token + float newScore = 0.f; if (laneIdx < NumTopExperts) { - newScore = static_cast(score) - static_cast(maxScore); - newScore = static_cast(exp(newScore)); + newScore = static_cast(score) - maxScore; + newScore = expf(newScore); sumScore += newScore; } sumScore = cg::reduce(warp, sumScore, cg::plus()); @@ -183,46 +186,40 @@ __device__ void routingPermutation(KernelParams params, using OutputT = typename KernelParams::OutputT; using TypePacked = PackedScoreIdx; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + static constexpr int ExpertsPerThread = + MaxNumExperts <= NumThreads ? 1 : MaxNumExperts / NumThreads; + static_assert(MaxNumExperts <= NumThreads || MaxNumExperts % NumThreads == 0, + "MaxNumExperts must be <= NumThreads or a multiple of NumThreads"); + static constexpr int MaxNumTokensSingleCluster = NumBlocksPerCluster * NumThreads; - // Number of threads in the cluster. static constexpr int NumThreadsPerCluster = NumThreads * NumBlocksPerCluster; - // same as max num tokens static constexpr int MaxExpandedIdxPerThread = (MaxNumTokensSingleCluster * MaxNumTopExperts + NumThreadsPerCluster - 1) / NumThreadsPerCluster; - // Needed for the exclusive sum of token offsets. 
- // Note: the scan might include more bins than needed, with bin counts of 0 to pad using Scan = cub::BlockScan; __shared__ typename Scan::TempStorage tempStorage; uint32_t const clusterThreadIdx = NumThreads * clusterBlockRank + threadIdx.x; auto expandedIdxSize = params.mNumTokens * params.mTopK; - // number of experts is bounded by number of threads - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[NumThreads]; - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[NumThreads]; + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[MaxNumExperts]; - // pre-fill the counts with 0 - if (threadIdx.x < params.mNumExperts) { - smemExpertCount[threadIdx.x] = 0; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + smemExpertCount[expert] = 0; + } } __syncthreads(); - // each thread keeps some number of "expanded indexes" assigned to it - // note that expanded indexes simply represent tokens here. - // for each of these, we keep the associated expert and offset within expert in registers int32_t expertIndexes[MaxExpandedIdxPerThread]; int32_t expertOffsets[MaxExpandedIdxPerThread]; auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a - // time, and branch between a fast path without bound checks and a slow path with bound checks. - // TODO(mjoux): potentially add this back for perf tuning - // int constexpr IterStride = 4; - // static_assert(MaxExpandedIdxPerThread % IterStride == 0); - - // Define a lambda to avoid code duplication in both branches. 
auto loopBody = [&](int ii, int expandedIdx) { TypePacked scoreIdx; if constexpr (LoadExpertIdxFromGlobal) { @@ -240,10 +237,9 @@ __device__ void routingPermutation(KernelParams params, } expertIndexes[ii] = scoreIdx.idx; - // check whether this expert is local to our GPU at all and ignore if not auto localExpertIdx = scoreIdx.idx - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + scoreIdx.idx, 1) : 0; if (params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) { params.mPtrTopKWeights[expandedIdx] = OutputT{scoreIdx.score}; @@ -253,7 +249,6 @@ __device__ void routingPermutation(KernelParams params, int constexpr IterStride = 4; #pragma unroll for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { - // Whether it's safe to do multiple iterations without bound checks. bool const takeFastPath = (ii0 + IterStride) * NumThreadsPerCluster <= expandedIdxSize; if (takeFastPath) { #pragma unroll @@ -279,86 +274,82 @@ __device__ void routingPermutation(KernelParams params, } } } - // Make local histogram (token counts per expert) available to all threads in the cluster. __cluster_barrier_arrive(); __cluster_barrier_wait(); - // - // Each thread now represents one expert - // + int32_t count[ExpertsPerThread]; + int32_t blockExpertOffset[ExpertsPerThread]; - // Total number of tokens for this expert. - int32_t count = 0; - // Per-expert offset for this block. - int32_t blockExpertOffset = 0; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + count[e] = 0; + blockExpertOffset[e] = 0; - if (threadIdx.x < params.mNumExperts) { - // Get the histogram bin from each rank for this expert. 
- int32_t expertCounts[NumBlocksPerCluster]; + if (expert < params.mNumExperts) { + int32_t expertCounts[NumBlocksPerCluster]; #pragma unroll - for (int rank = 0; rank < NumBlocksPerCluster; rank++) { - int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank); - expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[threadIdx.x] : 0; - } + for (int rank = 0; rank < NumBlocksPerCluster; rank++) { + int32_t const* remoteSmem = cg::cluster_group::map_shared_rank(smemExpertCount, rank); + expertCounts[rank] = rank * NumWarps < params.mNumTokens ? remoteSmem[expert] : 0; + } - // Compute an exclusive prefix sum of the block-local count. #pragma unroll - for (int rank = 0; rank < NumBlocksPerCluster; rank++) { - if (rank == clusterBlockRank) { - blockExpertOffset = count; + for (int rank = 0; rank < NumBlocksPerCluster; rank++) { + if (rank == clusterBlockRank) { + blockExpertOffset[e] = count[e]; + } + count[e] += expertCounts[rank]; } - count += expertCounts[rank]; } } - // Arrive: we do not access distributed shared memory after this point. __cluster_barrier_arrive(); - // Compute the runtime config for projections - // Whether or not an expert is local is taken into account when smemExpertCount is computed - // so we do not need to take it into account here. 
- - int32_t numCta; - if constexpr (KernelParams::isPow2) { - numCta = divUpLog2(count, params.mPaddingLog2); - } else { - numCta = divUpTileN(count, params.mTileTokensDim); + int32_t numCta[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + if constexpr (KernelParams::isPow2) { + numCta[e] = divUpLog2(count[e], params.mPaddingLog2); + } else { + numCta[e] = divUpTileN(count[e], params.mTileTokensDim); + } } - int32_t ctaOffset; + int32_t ctaOffset[ExpertsPerThread]; int32_t numNonExitingCtas; Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - if (threadIdx.x < params.mNumExperts) { - // Strided loop to share this work between blocks. - for (int32_t cta = clusterBlockRank; cta < numCta; cta += NumBlocksPerCluster) { - const int32_t localExpertIdx = - (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + for (int32_t cta = clusterBlockRank; cta < numCta[e]; cta += NumBlocksPerCluster) { + const int32_t localExpertIdx = + (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset[e] + cta] = localExpertIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset[e] + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset[e], params.mPaddingLog2) + count[e]; + } else { + mnLimit1 = mulTileN(ctaOffset[e] + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset[e], params.mTileTokensDim) + count[e]; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset[e] + cta] = min(mnLimit1, mnLimit2); + } + + int32_t offset; if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - 
mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; + offset = mulLog2(ctaOffset[e], params.mPaddingLog2); } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; + offset = mulTileN(ctaOffset[e], params.mTileTokensDim); } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); - } - - // get the padded offset associated with this expert - int32_t offset; - if constexpr (KernelParams::isPow2) { - offset = mulLog2(ctaOffset, params.mPaddingLog2); - } else { - offset = mulTileN(ctaOffset, params.mTileTokensDim); + smemExpertOffset[expert] = offset + blockExpertOffset[e]; } - // write expert offsets to shared - smemExpertOffset[threadIdx.x] = offset + blockExpertOffset; } - // write out padded count if (clusterBlockRank == 0 && warpIdx == NumWarps - 1 && cute::elect_one_sync()) { int32_t permutedIdxSize; if constexpr (KernelParams::isPow2) { @@ -370,28 +361,15 @@ __device__ void routingPermutation(KernelParams params, params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } - // make expert offsets available to all threads __syncthreads(); - - // Wait: we cannot exit while other blocks may be accessing the current block's shared memory. - // Note: I observed a perf benefit to doing this before the final loop so the compiler can - // implement break with EXIT. __cluster_barrier_wait(); - // trigger the secondary kernel when using PDL - // We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx, - // mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens - // TODO: this is not sufficient to ensure visibility in the next kernel! 
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } #endif - // each thread has the same "expanded indexes" assigned to it as above - // at this point, we know the final offsets of experts and the offsets within - // experts, which allows writing the final index values - #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ++ii) { auto expandedIdx = static_cast(clusterThreadIdx) + ii * NumThreadsPerCluster; @@ -399,10 +377,9 @@ __device__ void routingPermutation(KernelParams params, break; } auto expertIdx = expertIndexes[ii]; - // check whether this expert is local to our GPU at all auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; auto tokenIdx = expandedIdx / params.mTopK; auto permutedIdx = isLocalExpert ? int32_t{smemExpertOffset[expertIdx]} + expertOffsets[ii] : int32_t{-1}; @@ -419,115 +396,111 @@ __device__ void routingPermutation(KernelParams params, } //////////////////////////////////////////////////////////////////////////////////////////////////// -// Two-step approach (if number of tokens exceed limits of what cluster / cooperative launch -// variants can handle): in order to minimize the amount of data to exchange through global memory, -// we will compute the local histograms in smem twice: the first kernel will get us the total number -// of tokens per expert. The second kernel will use the smem and L2 atomics to get corresponding -// element and tile offsets. -// -// Note: the histogram calculation could also be fused with routingMainKernel, but this might be -// inefficient if we have one CTA per token doing a single global atomic. 
+ template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) +__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts + : 1024) routingIndicesHistogramKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + static constexpr int NumThreadsBlock = MaxNumExperts <= 1024 ? MaxNumExperts : 1024; + static constexpr int ExpertsPerThread = MaxNumExperts / NumThreadsBlock; + static_assert(MaxNumExperts % NumThreadsBlock == 0, + "MaxNumExperts must be a multiple of NumThreadsBlock"); - // number of experts is bounded by number of threads - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; - // For unrolling. uint32_t constexpr NumEltsPerThread = 8; - // Pre-fill the counts with 0 - if (threadIdx.x < params.mNumExperts) { - smemExpertCount[threadIdx.x] = 0; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + smemExpertCount[expert] = 0; + } } __syncthreads(); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid and trigger secondary kernel. 
if constexpr (KernelParams::UsePdl) { cudaGridDependencySynchronize(); cudaTriggerProgrammaticLaunchCompletion(); } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#endif uint32_t const expandedIdxSize = params.mNumTokens * params.mTopK; uint32_t const localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - uint32_t const gridBlockOffset = blockIdx.x * KernelParams::MaxNumExperts; - uint32_t const gridStride = gridDim.x * KernelParams::MaxNumExperts; + uint32_t const gridBlockOffset = blockIdx.x * NumThreadsBlock; + uint32_t const gridStride = gridDim.x * NumThreadsBlock; - // Define a lambda to avoid code duplication in branches. auto loopBody = [&](int expandedIdx) { PackedScoreIdx scoreIdx; int idx; if (params.mPtrTopKIds != nullptr) { idx = params.mPtrTopKIds[expandedIdx]; } else { - // If params.mPtrTopKIds != nullptr, we don't need to store the weights if (params.mPtrTopKWeights != nullptr) { scoreIdx = params.mPtrTopKPacked[expandedIdx]; idx = scoreIdx.idx; params.mPtrTopKWeights[expandedIdx] = static_cast(scoreIdx.score); } } - // check whether this expert is local to our GPU at all and ignore if not auto localExpertIdx = idx - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; if (isLocalExpert) { atomicAdd(&smemExpertCount[idx], 1); } }; - // Grid-stride loop. 
for (uint32_t expandedIdx0 = gridBlockOffset * NumEltsPerThread; expandedIdx0 < expandedIdxSize; expandedIdx0 += gridStride * NumEltsPerThread) { - // Fast path if bound checks aren't necessary - if (expandedIdx0 + NumEltsPerThread * KernelParams::MaxNumExperts <= expandedIdxSize) { + if (expandedIdx0 + NumEltsPerThread * NumThreadsBlock <= expandedIdxSize) { #pragma unroll for (uint32_t ii = 0; ii < NumEltsPerThread; ii++) { - uint32_t expandedIdx = expandedIdx0 + ii * KernelParams::MaxNumExperts + threadIdx.x; + uint32_t expandedIdx = expandedIdx0 + ii * NumThreadsBlock + threadIdx.x; loopBody(expandedIdx); } } else { for (uint32_t expandedIdx = expandedIdx0 + threadIdx.x; expandedIdx < expandedIdxSize; - expandedIdx += KernelParams::MaxNumExperts) { + expandedIdx += NumThreadsBlock) { loopBody(expandedIdx); } } } __syncthreads(); - // - // Each thread now represents one expert - // - - // Reduce histograms with atomics. - if (threadIdx.x < params.mNumExperts) { - int32_t const localExpertCount = smemExpertCount[threadIdx.x]; - atomicAdd(¶ms.mPtrExpertCounts[threadIdx.x], localExpertCount); +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + int32_t const localExpertCount = smemExpertCount[expert]; + atomicAdd(¶ms.mPtrExpertCounts[expert], localExpertCount); + } } } //////////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) +__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? 
KernelParams::MaxNumExperts + : 1024) routingIndicesOffsetsKernel(KernelParams params) { using OutputT = typename KernelParams::OutputT; - - // number of experts is bounded by number of threads - __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[KernelParams::MaxNumExperts]; - __shared__ int32_t __attribute((aligned(128))) smemExpertCount[KernelParams::MaxNumExperts]; - __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[KernelParams::MaxNumExperts]; - // needed for the exclusive sum of token offsets - using Scan = cub::BlockScan; + static constexpr int MaxNumExperts = KernelParams::MaxNumExperts; + static constexpr int NumThreadsBlock = MaxNumExperts <= 1024 ? MaxNumExperts : 1024; + static constexpr int ExpertsPerThread = MaxNumExperts / NumThreadsBlock; + static_assert(MaxNumExperts % NumThreadsBlock == 0, + "MaxNumExperts must be a multiple of NumThreadsBlock"); + + __shared__ int32_t __attribute((aligned(128))) smemExpertOffset[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertCount[MaxNumExperts]; + __shared__ int32_t __attribute((aligned(128))) smemExpertTileOffset[MaxNumExperts]; + using Scan = cub::BlockScan; __shared__ typename Scan::TempStorage tempStorage; static constexpr int MaxExpandedIdxPerThread = NumEltsPerOffsetTilePerThread; - static constexpr int MaxExpandedIdxPerBlock = - KernelParams::MaxNumExperts * MaxExpandedIdxPerThread; + static constexpr int MaxExpandedIdxPerBlock = NumThreadsBlock * MaxExpandedIdxPerThread; int32_t const warpIdx = __shfl_sync(0xffffffff, threadIdx.x / WarpSize, 0); @@ -536,57 +509,48 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) (expandedIdxSize + MaxExpandedIdxPerBlock - 1) / (MaxExpandedIdxPerBlock); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid. 
if constexpr (KernelParams::UsePdl) { cudaGridDependencySynchronize(); } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - - // The expert offsets are common to all tiles of all blocks. - // Load the histogram, scan it and write offsets to shared memory. - // Note: the scan is redundant in all CTAs. Would it make sense to use an intermediate kernel for - // the scan, with PDL? - - // - // Each thread represents one expert. - // - - // Get total count for this expert. - int32_t count = (threadIdx.x < params.mNumExperts) ? params.mPtrExpertCounts[threadIdx.x] : 0; - - // Compute the runtime config for projections - // Whether or not an expert is local is taken into account when the histogram is computed - // so we do not need to take it into account here. - // const int32_t numCta = divUpLog2(count, params.mPaddingLog2); - int32_t numCta; - if constexpr (KernelParams::isPow2) { - numCta = divUpLog2(count, params.mPaddingLog2); - } else { - numCta = divUpTileN(count, params.mTileTokensDim); +#endif + + int32_t count[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + count[e] = (expert < params.mNumExperts) ? 
params.mPtrExpertCounts[expert] : 0; } - int32_t ctaOffset; - int32_t numNonExitingCtas; - Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - if (threadIdx.x < params.mNumExperts) { - // Get the padded offset associated with this expert - int32_t offset; + int32_t numCta[ExpertsPerThread]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { if constexpr (KernelParams::isPow2) { - offset = mulLog2(ctaOffset, params.mPaddingLog2); + numCta[e] = divUpLog2(count[e], params.mPaddingLog2); } else { - offset = mulTileN(ctaOffset, params.mTileTokensDim); + numCta[e] = divUpTileN(count[e], params.mTileTokensDim); } + } + int32_t ctaOffset[ExpertsPerThread]; + int32_t numNonExitingCtas; + Scan(tempStorage).ExclusiveSum(numCta, ctaOffset, numNonExitingCtas); - // Write expert offsets to shared - smemExpertOffset[threadIdx.x] = offset; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + int32_t offset; + if constexpr (KernelParams::isPow2) { + offset = mulLog2(ctaOffset[e], params.mPaddingLog2); + } else { + offset = mulTileN(ctaOffset[e], params.mTileTokensDim); + } + smemExpertOffset[expert] = offset; + } } - // Sync to make expert offsets available to all threads. __syncthreads(); - // The first block writes out padded count - if (blockIdx.x == 0 && warpIdx == KernelParams::MaxNumExperts / WarpSize - 1 && - cute::elect_one_sync()) { + if (blockIdx.x == 0 && warpIdx == NumThreadsBlock / WarpSize - 1 && cute::elect_one_sync()) { int32_t permutedIdxSize; if constexpr (KernelParams::isPow2) { permutedIdxSize = mulLog2(numNonExitingCtas, params.mPaddingLog2); @@ -597,86 +561,76 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) params.mPtrNumNonExitingCtas[0] = numNonExitingCtas; } - if (threadIdx.x < params.mNumExperts) { - // Strided loop to share this work between blocks. 
- for (int32_t cta = blockIdx.x; cta < numCta; cta += gridDim.x) { - const int32_t localExpertIdx = - (threadIdx.x - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; - params.mPtrCtaIdxXyToBatchIdx[ctaOffset + cta] = localExpertIdx; - int32_t mnLimit1; - int32_t mnLimit2; - if constexpr (KernelParams::isPow2) { - mnLimit1 = mulLog2(ctaOffset + cta + 1, params.mPaddingLog2); - mnLimit2 = mulLog2(ctaOffset, params.mPaddingLog2) + count; - } else { - mnLimit1 = mulTileN(ctaOffset + cta + 1, params.mTileTokensDim); - mnLimit2 = mulTileN(ctaOffset, params.mTileTokensDim) + count; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + for (int32_t cta = blockIdx.x; cta < numCta[e]; cta += gridDim.x) { + const int32_t localExpertIdx = + (expert - params.mLocalExpertsStartIdx) >> params.mLocalExpertsStrideLog2; + params.mPtrCtaIdxXyToBatchIdx[ctaOffset[e] + cta] = localExpertIdx; + int32_t mnLimit1; + int32_t mnLimit2; + if constexpr (KernelParams::isPow2) { + mnLimit1 = mulLog2(ctaOffset[e] + cta + 1, params.mPaddingLog2); + mnLimit2 = mulLog2(ctaOffset[e], params.mPaddingLog2) + count[e]; + } else { + mnLimit1 = mulTileN(ctaOffset[e] + cta + 1, params.mTileTokensDim); + mnLimit2 = mulTileN(ctaOffset[e], params.mTileTokensDim) + count[e]; + } + params.mPtrCtaIdxXyToMnLimit[ctaOffset[e] + cta] = min(mnLimit1, mnLimit2); } - params.mPtrCtaIdxXyToMnLimit[ctaOffset + cta] = min(mnLimit1, mnLimit2); } } - // - // Now loop on indices and compute offsets. - // - - // Grid-stride loop on 1D "tiles" of input indices. for (uint32_t tileIdx = blockIdx.x; tileIdx < numTiles; tileIdx += gridDim.x) { if (tileIdx > 0) { - // Sync for safe reuse of smem buffers. 
__syncthreads(); } - // Pre-fill the counts with 0 - if (threadIdx.x < params.mNumExperts) { - smemExpertCount[threadIdx.x] = 0; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + smemExpertCount[expert] = 0; + } } __syncthreads(); - // each thread keeps has some number of "expanded indexes" assigned to it - // for each of these, we keep the associated expert and offset within expert in registers int32_t expertIndexes[MaxExpandedIdxPerThread]; int32_t expertOffsets[MaxExpandedIdxPerThread]; auto localExpertExtent = params.mNumLocalExperts << params.mLocalExpertsStrideLog2; - // Define a lambda to avoid code duplication in branches. auto loopBody = [&](int ii, int expandedIdx) { expertIndexes[ii] = params.mPtrTopKIds ? params.mPtrTopKIds[expandedIdx] : params.mPtrTopKPacked[expandedIdx].idx; - // check whether this expert is local to our GPU at all and ignore if not auto localExpertIdx = expertIndexes[ii] - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; expertOffsets[ii] = isLocalExpert ? atomicAdd(smemExpertCount + expertIndexes[ii], 1) : 0; }; - // For all tiles but the last, all indices are in bounds. if (tileIdx < numTiles - 1) { #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { - auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; loopBody(ii, expandedIdx); } } else { - // For the last tile, we need to exit the loop when out of bounds. 
- // In order to avoid a serialization LDG-ATOMS-LDG-ATOMS-..., we skip multiple iterations at a - // time, and branch between a fast path without bound checks and a slow path with bound checks int constexpr IterStride = 4; static_assert(MaxExpandedIdxPerThread % IterStride == 0); #pragma unroll for (int32_t ii0 = 0; ii0 < MaxExpandedIdxPerThread; ii0 += IterStride) { - // Whether it's safe to do multiple iterations without bound checks. bool const takeFastPath = - tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * KernelParams::MaxNumExperts <= + tileIdx * MaxExpandedIdxPerBlock + (ii0 + IterStride) * NumThreadsBlock <= expandedIdxSize; if (takeFastPath) { #pragma unroll for (int32_t jj = 0; jj < IterStride; jj++) { int const ii = ii0 + jj; auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; loopBody(ii, expandedIdx); } } else { @@ -685,7 +639,7 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) for (int32_t jj = 0; jj < IterStride; jj++) { int const ii = ii0 + jj; auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; if (expandedIdx >= expandedIdxSize) { doBreak = true; break; @@ -699,33 +653,25 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } } - // Make local histogram (token counts per expert) available to all threads in the block. __syncthreads(); - // - // Each thread now represents one expert - // - - if (threadIdx.x < params.mNumExperts) { - // Add the local bin count to the common bin count and get a per-CTA offset. We use the second - // half of the histogram buffer for this histogram, because the first half already holds the - // reduced histogram from the previous kernel. 
- int32_t const localExpertCount = smemExpertCount[threadIdx.x]; - int32_t const tileExpertOffset = - atomicAdd(¶ms.mPtrExpertCounts[params.mNumExperts + threadIdx.x], localExpertCount); - - // Make per-expert tile offsets available to all threads in the block. - smemExpertTileOffset[threadIdx.x] = tileExpertOffset + smemExpertOffset[threadIdx.x]; +#pragma unroll + for (int e = 0; e < ExpertsPerThread; e++) { + int expert = threadIdx.x * ExpertsPerThread + e; + if (expert < params.mNumExperts) { + int32_t const localExpertCount = smemExpertCount[expert]; + int32_t const tileExpertOffset = + atomicAdd(¶ms.mPtrExpertCounts[params.mNumExperts + expert], localExpertCount); + smemExpertTileOffset[expert] = tileExpertOffset + smemExpertOffset[expert]; + } } __syncthreads(); - // Add tile offset and element offset and write to global memory. auto storeLoopBody = [&](int ii, int expandedIdx) { int32_t expertIdx = expertIndexes[ii]; - // check whether this expert is local to our GPU at all auto localExpertIdx = static_cast(expertIdx) - params.mLocalExpertsStartIdx; auto isLocalExpert = localExpertIdx >= 0 && localExpertIdx < localExpertExtent && - (localExpertIdx & params.mLocalExpertsStrideLog2) == 0; + (localExpertIdx & ((1 << params.mLocalExpertsStrideLog2) - 1)) == 0; auto tokenIdx = expandedIdx / params.mTopK; auto permutedIdx = isLocalExpert ? (expertOffsets[ii] + smemExpertTileOffset[expertIdx]) : int32_t{-1}; @@ -739,19 +685,16 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) params.mPtrPermutedIdxToTokenIdx[permutedIdx] = tokenIdx; } }; - // Bound checks only in last tile. 
if (tileIdx < numTiles - 1) { #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { - auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; storeLoopBody(ii, expandedIdx); } } else { #pragma unroll for (int32_t ii = 0; ii < MaxExpandedIdxPerThread; ii += 1) { - auto expandedIdx = - tileIdx * MaxExpandedIdxPerBlock + ii * KernelParams::MaxNumExperts + threadIdx.x; + auto expandedIdx = tileIdx * MaxExpandedIdxPerBlock + ii * NumThreadsBlock + threadIdx.x; if (expandedIdx >= expandedIdxSize) { break; } @@ -761,40 +704,38 @@ __global__ void __launch_bounds__(KernelParams::MaxNumExperts) } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Trigger secondary kernel. - // Note: this does not guarantee the visibility of prior writes unless the consumer executes a - // dependency sync. if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#endif } //////////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void __launch_bounds__(KernelParams::MaxNumExperts) +__global__ void __launch_bounds__(KernelParams::MaxNumExperts <= 1024 ? KernelParams::MaxNumExperts + : 1024) routingInitExpertCounts(KernelParams params) { - // initialize the mPtrExpertCounts + static constexpr int NumThreadsBlock = + KernelParams::MaxNumExperts <= 1024 ? 
KernelParams::MaxNumExperts : 1024; + int32_t expertCountsNum = 2 * params.mNumExperts; - int32_t globalThreadIdx = blockIdx.x * KernelParams::MaxNumExperts + threadIdx.x; - int32_t globalThreadStride = gridDim.x * KernelParams::MaxNumExperts; + int32_t globalThreadIdx = blockIdx.x * NumThreadsBlock + threadIdx.x; + int32_t globalThreadStride = gridDim.x * NumThreadsBlock; #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid. if constexpr (KernelParams::UsePdl) { cudaGridDependencySynchronize(); } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#endif initArr(globalThreadIdx, expertCountsNum, globalThreadStride, params.mPtrExpertCounts, 0); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) - // Wait on primary grid. if constexpr (KernelParams::UsePdl) { cudaTriggerProgrammaticLaunchCompletion(); } -#endif // if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) +#endif } } // namespace routing } // namespace moe::dev diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index ba90742ce0..456bcd7a75 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -104,11 +104,13 @@ struct DataBase { int32_t mNumLocalExperts; }; -template +template struct KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; static constexpr int MaxNumExperts = MaxNumExperts_; + static constexpr int MaxNumTopExperts = MaxNumTopExperts_; static constexpr bool isPow2 = isPow2_; static constexpr bool UsePdl = UsePdl_; @@ -166,8 +168,6 @@ namespace routingDeepSeek { //////////////////////////////////////////////////////////////////////////////////////////////////// struct Data : public DataBase { tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; - tg::Dtype mDtypeBias{tg::Dtype::Bfloat16}; - tg::Dtype mDtypeScore{tg::Dtype::Fp32}; // // Grouped Gemm Launch Config Buffers // @@ -181,23 +181,18 @@ struct 
Data : public DataBase { bool mUseRoutingSoftmax; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; - using BiasT = BiasT_; using OutputT = OutputT_; static constexpr bool UseGroups = UseGroups_; - static constexpr int MaxNumTopExperts = MaxNumTopExperts_; PackedScoreIdx* mPtrTopKPacked = nullptr; - // OutputT* mPtrTopKWeightsFull = nullptr; - // Note: this variable(mPtrTopKWeightsFull) might need to be added back for the low-latency - // kernels for MoE in tllm-gen in the future - - BiasT const* mPtrRoutingBias = nullptr; + OutputT const* mPtrRoutingBias = nullptr; int32_t mNumExpertGroups = 0; int32_t mNumExpertsPerGroup = 0; @@ -211,9 +206,7 @@ struct KernelParams : public KernelParamsBase*)data.mPtrTopKPacked; - - // params.mPtrTopKWeightsFull = static_cast(data.mPtrTopKWeightsFull); - params.mPtrRoutingBias = static_cast(data.mPtrRoutingBias); + params.mPtrRoutingBias = static_cast(data.mPtrRoutingBias); params.mNumExpertGroups = data.mNumExpertGroups; params.mNumExpertsPerGroup = data.mNumExperts / data.mNumExpertGroups; @@ -239,8 +232,10 @@ struct Data : public DataBase { tg::Dtype mDtypeExpW{tg::Dtype::Bfloat16}; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; @@ -274,12 +269,17 @@ struct Data : public DataBase { bool mDoSoftmaxBeforeTopK{false}; bool mNormTopkProb{true}; // Default value is true for Qwen3 model - bool mApplySoftmaxAfterTopK{false}; + // If true, applies softmax normalization after selecting top-K experts. + // Use this for models that require post-selection normalization (e.g., specific Qwen variants). + // Mutually exclusive with mDoSoftmaxBeforeTopK when both normalization paths are active. + // NOTE: Don't need to use this variable for now. 
+ bool mApplySoftmaxAfterTopK{true}; }; -template -struct KernelParams : public KernelParamsBase { +template +struct KernelParams : public KernelParamsBase { using InputT = InputT_; using OutputT = OutputT_; diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh index 42aa877d26..a0b2ab6009 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh @@ -32,7 +32,7 @@ namespace cg = cooperative_groups; static constexpr int WarpSize = 32; static constexpr int MaxNumExpertsUnit = 128; -static constexpr int MaxNumTopK = 10; +static constexpr int MaxSupportedTopExperts = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -56,13 +56,10 @@ struct TopKRedType { TypeCmp compactTmp; memcpy(&compactTmp, &valueBits, sizeof(valueBits)); compactTmp = (compactTmp << moveBits) | (0xFFFF & (maxIdx - idx)); - // Use 65535 minus idx to give higher priority to elements with smaller indices. 
return compactTmp; } static __host__ __device__ inline void unpack(TypeExpW& value, int32_t& index, TypeCmp cmp) { - // Since idx is always smaller than 65536 and positive, we can directly use it as the lower 16 - // bits index = maxIdx - static_cast(cmp & 0xFFFF); auto compactTmp = cmp >> moveBits; @@ -103,10 +100,67 @@ struct TopKRedType { topK[J].compVal = pairMin; \ } +template +struct IsPowerOf2 { + static constexpr bool value = (N > 0) && ((N & (N - 1)) == 0); +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct Sort; +struct Sort { + static_assert(N > 0 && N <= 64, "Sort only supports N in range [1, 64]"); + + static __device__ void run(RedType* topK) { + if constexpr (IsPowerOf2::value) { +#pragma unroll + for (int k = 2; k <= N; k *= 2) { +#pragma unroll + for (int j = k / 2; j > 0; j /= 2) { +#pragma unroll + for (int i = 0; i < N; ++i) { + int ixj = i ^ j; + if (ixj > i) { + if ((i & k) == 0) { + if (topK[i].compVal < topK[ixj].compVal) { + auto tmp = topK[i].compVal; + topK[i].compVal = topK[ixj].compVal; + topK[ixj].compVal = tmp; + } + } else { + if (topK[i].compVal > topK[ixj].compVal) { + auto tmp = topK[i].compVal; + topK[i].compVal = topK[ixj].compVal; + topK[ixj].compVal = tmp; + } + } + } + } + } + } + } else { +#pragma unroll + for (int pass = 0; pass < N; ++pass) { +#pragma unroll + for (int i = 0; i < N - 1; i += 2) { + if (topK[i].compVal < topK[i + 1].compVal) { + auto tmp = topK[i].compVal; + topK[i].compVal = topK[i + 1].compVal; + topK[i + 1].compVal = tmp; + } + } +#pragma unroll + for (int i = 1; i < N - 1; i += 2) { + if (topK[i].compVal < topK[i + 1].compVal) { + auto tmp = topK[i].compVal; + topK[i].compVal = topK[i + 1].compVal; + topK[i + 1].compVal = tmp; + } + } + } + } + } +}; template struct Sort<1, RedType> { @@ -150,24 +204,22 @@ __forceinline__ __device__ void reduceTopK(cg::thread_block_tile const RedType topK{value, idx}; typename RedType::TypeCmp 
packedMax{}; #pragma unroll - for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct - { + for (int kk = 0; kk < actualK; ++kk) { topK = kk > 0 && packedMax == topK.compVal ? RedType{minValue, idx} : topK; - // get the next largest value packedMax = topK.reduce(warp); RedType::unpack(out[kk], outIdx[kk], packedMax); } }; template -__forceinline__ __device__ void reduceTopKFunc(cg::thread_block_tile const& warp, - Type (&out)[K], int32_t (&outIdx)[K], - Type (&value)[N], int32_t (&idx)[N], - Type const minValue, int actualK = K) { +__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, + Type (&out)[K], int32_t (&outIdx)[K], Type (&value)[N], + int32_t (&idx)[N], Type const minValue, + int actualK = K) { static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); + static_assert(K <= WarpSize, "Top K must have K <= WarpSize"); static_assert(N > 0, "Top K must have N > 0"); - static_assert(N < 5, "Only support candidates number less than or equal to 128"); + static_assert(N <= 64, "Only support candidates number less than or equal to 64*32=2048"); using RedType = TopKRedType; RedType topK[N]; #pragma unroll @@ -178,9 +230,7 @@ __forceinline__ __device__ void reduceTopKFunc(cg::thread_block_tile c Sort::run(topK); typename RedType::TypeCmp packedMax{}; -#pragma unroll - for (int kk = 0; kk < actualK; ++kk) //@todo: check if actualK is correct - { + for (int kk = 0; kk < actualK; ++kk) { bool update = kk > 0 && packedMax == topK[0].compVal; #pragma unroll for (int nn = 0; nn < N; ++nn) { @@ -188,64 +238,11 @@ __forceinline__ __device__ void reduceTopKFunc(cg::thread_block_tile c : update ? 
topK[nn + 1] : topK[nn]; } - // get the next largest value packedMax = topK[0].reduce(warp); RedType::unpack(out[kk], outIdx[kk], packedMax); } }; -template -__forceinline__ __device__ void reduceTopK(cg::thread_block_tile const& warp, - Type (&out)[K], int32_t (&outIdx)[K], Type (&value)[N], - int32_t (&idx)[N], Type const minValue, - int actualK = K) { - static_assert(K > 0, "Top K must have K > 0"); - static_assert(K < WarpSize, "Top K must have K < WarpSize"); - static_assert(N > 0, "Top K must have N > 0"); - static_assert(N <= 16, "Only support candidates number less than or equal to 16*32=512"); - using RedType = TopKRedType; - - if constexpr (N <= 4) { - reduceTopKFunc(warp, out, outIdx, value, idx, minValue, actualK); - } else { - constexpr int numLoops = (N - 1) / 4 + 1; - constexpr int numResults = (numLoops * K - 1) / WarpSize + 1; - - Type topKBufferValue[numResults]; - int32_t topKBufferIdx[numResults]; - int32_t laneIdx = threadIdx.x % WarpSize; - - for (int ii = 0; ii < numResults; ++ii) { - topKBufferValue[ii] = minValue; - topKBufferIdx[ii] = ii * WarpSize - 1; - } - for (int loop = 0; loop < numLoops; ++loop) { - int start = loop * 4; - Type topKValue[K]; - int32_t topKIdx[K]; - Type inValue[4]; - int32_t inIdx[4]; - for (int i = 0; i < 4; ++i) { - inValue[i] = value[start + i]; - inIdx[i] = idx[start + i]; - } - reduceTopKFunc(warp, topKValue, topKIdx, inValue, inIdx, minValue, actualK); - int inOffset = laneIdx % K; - if (laneIdx >= loop * K && laneIdx < (loop + 1) * K) { - topKBufferValue[0] = topKValue[inOffset]; - topKBufferIdx[0] = topKIdx[inOffset]; - } - if (loop == numLoops - 1 && (laneIdx < (numLoops * K - WarpSize))) { - topKBufferValue[1] = topKValue[inOffset]; - topKBufferIdx[1] = topKIdx[inOffset]; - } - } - - reduceTopKFunc(warp, out, outIdx, topKBufferValue, topKBufferIdx, minValue, - actualK); - } -}; - #undef TOPK_SWAP } // namespace topk } // namespace moe::dev::routing diff --git a/tests/moe/test_trtllm_gen_fused_moe.py 
b/tests/moe/test_trtllm_gen_fused_moe.py index 62f9860644..887472ddbd 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -2611,7 +2611,6 @@ def run_moe_test( # Validation checks assert top_k <= num_experts - assert top_k <= 22 if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0): assert top_k_groups <= 4 assert num_experts > n_groups @@ -2870,6 +2869,27 @@ def run_moe_test( }, id="Qwen3_next", ), + pytest.param( + { + "num_experts": 2048, + "top_k": 32, + "padding": 8, + "n_groups": None, + "top_k_groups": None, + "routed_scaling": None, + "has_routing_bias": False, + "routing_method_type": RoutingMethodType.Renormalize, + "compatible_moe_impls": [ + FP8BlockScaleMoe, + FP4Moe, + BF16Moe, + MxInt4BlockScaleMoe, + ], + "compatible_intermediate_size": [384], + "enable_autotune": True, + }, + id="RoutingRenormalize_large_experts", + ), ], ) @pytest.mark.parametrize(