Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions cpp/tensorrt_llm/kernels/noAuxTcKernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,23 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
// Number of threads per warp on all current NVIDIA architectures.
static constexpr int WARP_SIZE = 32;
// Per-model expert counts supported by the specialized top-k routing kernel.
// Nemotron routes over 512 experts (selecting the top 22 from a single group —
// see the special case in invokeNoAuxTc below).
static constexpr int NumNemotronExperts = 512;
// presumably the expert count of the Kimi-K2 MoE model — TODO confirm against model config
static constexpr int NumKimiK2Experts = 384;
// presumably the expert count of the DeepSeek-V3 MoE model — TODO confirm against model config
static constexpr int NumDeepseekExperts = 256;
// Largest expert count any specialization of the kernel can handle.
static constexpr int MaxSupportedExpertCount = std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts});
// Threshold used by the launcher to pick between kernel template instantiations
// (experts_per_group * topk_group must fit within this for the multi-group path).
static constexpr int MaxNumExpertsUnit = 128;
// Each group's score is the sum of its top-2 expert scores.
static constexpr int NumTopGroupScores = 2;
static constexpr int MaxNumTopExperts = 8;
// Default upper bound on top-k experts selected per token (template default).
static constexpr int DefaultMaxNumTopExperts = 8;
// Absolute upper bound on top-k experts; 22 matches the Nemotron 512-expert case.
static constexpr int MaxSupportedTopExperts = 22;
// Maximum number of expert groups considered in the final selection stage.
static constexpr int MaxNumTopGroups = 4;

// Numerically accurate sigmoid via the identity
//   sigmoid(v) = 0.5 * tanh(0.5 * v) + 0.5,
// which avoids the catastrophic loss of precision that 1 / (1 + expf(-v))
// exhibits for large-magnitude negative inputs.
static __device__ inline float sigmoid_accurate(float val)
{
    return 0.5f * tanhf(0.5f * val) + 0.5f;
}

template <typename InputT, typename BiasT, typename OutputT, typename IdxT, int MaxNumExperts, bool UseGroups>
template <typename InputT, typename BiasT, typename OutputT, typename IdxT, int MaxNumExperts, bool UseGroups,
int MaxNumTopExperts = DefaultMaxNumTopExperts>
__global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, IdxT* topkIndices, BiasT* routingBias,
int64_t const numTokens, int64_t const numGroup, int64_t const topkGroup, int64_t const topk,
int64_t const numExperts, int64_t const numExpertsPerGroup, double const routedScalingFactor)
Expand Down Expand Up @@ -132,7 +136,7 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
/* minValue */ invalidScoreFloat);

// get the final group score and write it to shared
if (laneIdx == 0)
if (warp.thread_rank() == 0)
{
auto groupScore = topExpGroupScores[0] + topExpGroupScores[1];
smemGroupScores[warpIdx] = groupScore;
Expand All @@ -151,22 +155,19 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx

reduce_topk::reduceTopK(warp, topGroups, topGroupIdx, groupScore, laneIdx,
/* minValue */ invalidScoreFloat);

// final expert selection: get relevant indexes and scores from shared

#pragma unroll
for (int ii = 0; ii < MaxNumTopGroups; ++ii)
{ // bound of numGroup
auto groupIdx = topGroupIdx[ii];
expertIdxGroup[ii] = groupIdx * numExpertsPerGroup + laneIdx;

expertScoreGroup[ii]
= groupIdx < numGroup && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat;
= (ii < topkGroup) && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat;
}

tensorrt_llm::kernels::reduce_topk::reduceTopK(warp, topScores, topExperts, expertScoreGroup,
expertIdxGroup,
/* minValue */ invalidScoreFloat, topk);
tensorrt_llm::kernels::reduce_topk::reduceTopK(
warp, topScores, topExperts, expertScoreGroup, expertIdxGroup, /* minValue */ invalidScoreFloat, topk);
}
}
else if constexpr (MaxNumExperts > MaxNumExpertsUnit)
Expand Down Expand Up @@ -197,11 +198,16 @@ __global__ void deepseek_v3_topk_kernel(InputT* scores, OutputT* topkValues, Idx
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx];
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx];
}
else if (laneIdx >= topk && laneIdx < MaxNumTopExperts)
{
smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = invalidScoreFloat;
smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = MaxNumExperts - 1;
}
}
__syncthreads();
if (warpIdx == 0)
{
int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WARP_SIZE + 1;
int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WARP_SIZE + 1;
float intermidiateScore[NumInterTopKPerThread];
int32_t intermidiateExpert[NumInterTopKPerThread];
for (int i = laneIdx; i < NumInterTopKPerThread * WARP_SIZE; i += WARP_SIZE)
Expand Down Expand Up @@ -268,11 +274,11 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk
{

// Check if we can use the optimized deepseek_v3_topk_kernel
bool const is_single_group = (n_group == 1) && (num_experts <= NumKimiK2Experts);
bool const is_single_group = (n_group <= 1) && (num_experts <= MaxSupportedExpertCount);

int64_t const experts_per_group = num_experts / n_group;
bool const is_multi_group = (n_group != 1) && (num_experts <= NumDeepseekExperts)
&& (experts_per_group <= WARP_SIZE) && (experts_per_group * topk_group <= MaxNumExpertsUnit);
bool const is_multi_group = (n_group > 1) && (num_experts <= NumDeepseekExperts) && (experts_per_group <= WARP_SIZE)
&& (experts_per_group * topk_group <= MaxNumExpertsUnit);

if (is_single_group || is_multi_group)
{
Expand All @@ -281,7 +287,20 @@ void invokeNoAuxTc(InputT* scores, BiasT* bias, OutputT* topk_values, IdxT* topk
int num_threads = NumDeepseekExperts;
if (is_single_group)
{
if (num_experts > MaxNumExpertsUnit)
// Special case for Nemotron, which selects top 22 from 512 experts, and 1 group only.
if (num_experts == NumNemotronExperts && n_group == 1 && topk == MaxSupportedTopExperts)
{
kernel_instance = &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, NumNemotronExperts, false,
MaxSupportedTopExperts>;
num_threads = NumNemotronExperts;
}
else if (num_experts > NumKimiK2Experts && num_experts <= MaxSupportedExpertCount)
{
kernel_instance
= &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, MaxSupportedExpertCount, false>;
num_threads = MaxSupportedExpertCount;
}
else if (num_experts > MaxNumExpertsUnit && num_experts <= NumKimiK2Experts)
{
kernel_instance = &deepseek_v3_topk_kernel<InputT, BiasT, OutputT, IdxT, NumKimiK2Experts, false>;
num_threads = NumKimiK2Experts;
Expand Down
28 changes: 14 additions & 14 deletions cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/DevKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,37 +182,37 @@ namespace moe::dev
TLLM_LOG_ERROR("Unsupported dtypeExpW"); \
}

#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT( \
data, coopLaunch, kernel, numBlocks, numThreads, smemSize, stream, extraFlag, forceFloatInput, numExperts) \
#define LAUNCH_ROUTING_WITH_NUM_EXPERTS_FORCE_FLOAT_INPUT(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \
stream, extraFlag, forceFloatInput, numExperts, numTopExperts) \
if (data.mDtypeExpW == tg::Dtype::Fp32 && extraFlag) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, true), kernel, numBlocks, numThreads, \
smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Fp32) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, false), kernel, numBlocks, numThreads, \
smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, numExperts, numTopExperts, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag && forceFloatInput) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, true), kernel, \
numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && extraFlag) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, true), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, true), \
kernel, numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16 && forceFloatInput) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, numExperts, numTopExperts, false), kernel, \
numBlocks, numThreads, smemSize, stream); \
} \
else if (data.mDtypeExpW == tg::Dtype::Bfloat16) \
{ \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, false), kernel, numBlocks, \
numThreads, smemSize, stream); \
LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, false), \
kernel, numBlocks, numThreads, smemSize, stream); \
} \
else \
{ \
Expand Down
Loading