Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions csrc/moe_utils_binding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,14 @@ void moe_sort(
routingData.mLocalExpertsStrideLog2 = 0;
routingData.mNumLocalExperts = num_local_experts;

// Fused shared expert fields β€” unused in cute DSL moe_sort path, but must be zero-initialized
// because the routing kernel reads mNumFusedSharedExperts unconditionally (adds it to numExperts
// and topK at lines 576-577 of trtllm_fused_moe_routing_deepseek.cu).
routingData.mNumFusedSharedExperts = 0;
routingData.mSharedExpertTokenOffset = 0;
routingData.mSharedExpertNumTokens = 0;
routingData.mTotalExpertsPerToken = top_k;

// DeepSeekV3 specific parameters
// For moe_sort, we use n_group=1, topk_group=1 since experts are already selected
routingData.mNumExpertGroups = 1;
Expand Down
74 changes: 45 additions & 29 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -256,33 +256,36 @@ class FusedMoeLauncher {
Tensor num_non_exiting_ctas;

void prepare_routing_common() {
int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts;
int32_t const totalNumExperts = args->num_experts + args->num_fused_shared_experts;

// Allocate routing phase workspace tensors
num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device());
num_tokens_per_expert = alloc_tensor({totalNumExperts}, dl_int32, hidden_states.device());
int32_t max_num_padded_tokens =
tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount(
args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim);

total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device());

expanded_idx_to_permuted_idx =
alloc_tensor({args->num_tokens * args->top_k}, dl_int32, hidden_states.device());
alloc_tensor({args->num_tokens * totalExpertsPerToken}, dl_int32, hidden_states.device());

permuted_idx_to_token_idx =
alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device());

expert_indexes =
alloc_tensor({args->num_tokens, args->top_k}, dl_int32, hidden_states.device());
alloc_tensor({args->num_tokens, totalExpertsPerToken}, dl_int32, hidden_states.device());

// expert_weights allocation should be done by derived class since data type could vary

int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2);
int64_t const size_of_expert_count_histogram = std::max(totalNumExperts * 2, 256 * 2);
expert_count_histogram = alloc_tensor({size_of_expert_count_histogram},
dl_int32, // 256 is the max number of threads per block
// and max number of experts
hidden_states.device());

int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim(
args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim);
args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim);

cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device());

Expand Down Expand Up @@ -334,14 +337,17 @@ class FusedMoeLauncher {
this->activation_type, this->use_shuffled_weight, this->weight_layout);
}

int32_t const effectiveTopK = args->top_k + args->num_fused_shared_experts;
int32_t const effectiveLocalExperts = args->local_num_experts + args->num_fused_shared_experts;

if (moe_tactic == -1) {
moe_tactic = moe_runner->getDefaultValidConfigIndex(
args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts,
args->num_tokens);
moe_tactic = moe_runner->getDefaultValidConfigIndex(effectiveTopK, args->hidden_size,
args->intermediate_size,
effectiveLocalExperts, args->num_tokens);
}
auto valid_cfgs =
moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size,
args->local_num_experts, args->num_tokens);
moe_runner->getValidConfigIndices(effectiveTopK, args->hidden_size, args->intermediate_size,
effectiveLocalExperts, args->num_tokens);
auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic);
FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic,
" for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ",
Expand Down Expand Up @@ -377,8 +383,8 @@ class FusedMoeLauncher {

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, workspace.routing_expert_indexes,
args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset,
args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes,
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
Expand Down Expand Up @@ -910,12 +916,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
auto const routing_bias_dtype =
routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;
int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts;
// Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr
bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0;
if (!has_precomputed_weights) {
// Allocate expert_weights buffer for routing output
FusedMoeLauncher::expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());
FusedMoeLauncher::expert_weights = alloc_tensor({args->num_tokens, totalExpertsPerToken},
dl_bfloat16, hidden_states.device());
workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr();
} else {
workspace.expert_weights = const_cast<void*>(expert_weights.data_ptr());
Expand Down Expand Up @@ -946,12 +953,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8.";
TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8.";

int64_t const totalLocalExperts = args->local_num_experts + args->num_fused_shared_experts;
if (quantization_type == Fp8QuantizationType::DeepSeekFp8) {
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32)
<< "gemm1_weights_scale must be float.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts)
<< "gemm1_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), totalLocalExperts)
<< "gemm1_weights_scale has incorrect dim 0.";
TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0)
<< "intermediate_size must be a multiple of 128.";
TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1),
Expand All @@ -971,8 +979,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32)
<< "gemm2_weights_scale must be float.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts)
<< "gemm2_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), totalLocalExperts)
<< "gemm2_weights_scale has incorrect dim 0.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128)
<< "gemm2_weights_scale has incorrect shape.";
TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128)
Expand Down Expand Up @@ -1082,8 +1090,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher {
// routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes
routing_runner.run(
use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens,
args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset,
args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes,
args->num_experts, args->top_k, args->num_fused_shared_experts, args->n_group,
args->topk_group, args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, workspace.routing_expert_indexes,
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
Expand Down Expand Up @@ -1545,8 +1554,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher {

routing_runner.run(
args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k,
args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts,
args->routed_scaling_factor, static_cast<int*>(expert_indices.data_ptr()),
args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset,
args->local_num_experts, args->routed_scaling_factor,
static_cast<int*>(expert_indices.data_ptr()),
static_cast<int*>(expert_count_histogram.data_ptr()),
static_cast<int*>(total_num_padded_tokens.data_ptr()),
static_cast<int*>(expanded_idx_to_permuted_idx.data_ptr()),
Expand Down Expand Up @@ -1779,10 +1789,11 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights,
TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k,
Optional<int64_t> n_group, Optional<int64_t> topk_group, int64_t intermediate_size,
int64_t local_expert_offset, int64_t local_num_experts, Optional<double> routed_scaling_factor,
int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize,
bool enable_pdl, Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
Optional<int64_t> num_fused_shared_experts, Optional<int64_t> n_group,
Optional<int64_t> topk_group, int64_t intermediate_size, int64_t local_expert_offset,
int64_t local_num_experts, Optional<double> routed_scaling_factor, int64_t routing_method_type,
bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl,
Array<int64_t> config_index, Fp8QuantizationType quantization_type) {
// Basic type validation
auto dtype = hidden_states.dtype();

Expand Down Expand Up @@ -1843,9 +1854,13 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
auto const num_tokens = hidden_states.size(0);
auto const hidden_size = hidden_states.size(1);

int64_t const nFusedShared = num_fused_shared_experts.value_or(0);
int64_t const totalExpertsPerToken = top_k + nFusedShared;
int64_t const totalLocalExperts = local_num_experts + nFusedShared;

auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type);
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts);
std::set<int32_t> selected_tile_nums = computeSelectedTileN(
supported_tile_nums, num_tokens, totalExpertsPerToken, totalLocalExperts);

// Create a map of launchers for each tile size
std::unordered_map<int32_t, std::unique_ptr<Fp8BlockScaleLauncher>> launchers_map;
Expand All @@ -1855,6 +1870,7 @@ Array<Tensor> trtllm_fp8_block_scale_moe(
auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
args->num_tokens = num_tokens;
args->num_experts = num_experts;
args->num_fused_shared_experts = nFusedShared;
args->hidden_size = hidden_size;
args->hidden_size_output = args->hidden_size;
args->top_k = top_k;
Expand Down
40 changes: 35 additions & 5 deletions csrc/trtllm_fused_moe_routing_deepseek.cu
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,28 @@ __global__ void routingMainKernel(KernelParams params) {
auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm};

// write expert idx out already
auto idxTopK = blockIdx.x * params.mTopK + laneIdx;
auto idxTopK = blockIdx.x * params.mTotalExpertsPerToken + laneIdx;
auto idxShared = blockIdx.x * params.mTotalExpertsPerToken + params.mTopK + laneIdx;
if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) {
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(finalScore),
static_cast<int16_t>(expertIdx)};
params.mPtrTopKPacked[idxTopK] = packedScore;
}

if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKPacked != nullptr) {
PackedScoreIdx<OutputT> packedScore{static_cast<OutputT>(1.0F),
static_cast<int16_t>(params.mNumExperts + laneIdx)};
params.mPtrTopKPacked[idxShared] = packedScore;
}

if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr &&
params.mPtrTopKIds == nullptr) {
params.mPtrTopKWeights[idxTopK] = finalScore;
}

if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKWeights != nullptr) {
params.mPtrTopKWeights[idxShared] = static_cast<OutputT>(1.0F);
}
}
}
}
Expand Down Expand Up @@ -561,6 +572,11 @@ void runImpl(Data& data, void* stream) {
FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups,
"Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups,
data.mNumLimitedGroups);

int const numExperts = data.mNumExperts + data.mNumFusedSharedExperts;
int const topK = data.mTopK + data.mNumFusedSharedExperts;
int const numThreadsHist = getMaxNumExperts(numExperts);

// Test limits according to values passed in launch, see definition of LAUNCH_ROUTING_DEEPSEEK
if (data.mNumExperts <= NumKimiK2Experts) {
FLASHINFER_CHECK(
Expand All @@ -573,6 +589,9 @@ void runImpl(Data& data, void* stream) {
"When NumExperts > NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d",
MaxSupportedTopExperts, data.mTopK);
}
FLASHINFER_CHECK(topK <= MaxSupportedTopExperts,
"Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts,
topK);
FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d",
data.mTopK);
FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize,
Expand All @@ -598,14 +617,19 @@ void runImpl(Data& data, void* stream) {
data.mNumExperts / data.mNumExpertGroups <= WarpSize,
"Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d",
data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups);

FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
"Number of fused shared experts (%d) must be less than warp size.",
data.mNumFusedSharedExperts);
Comment on lines +621 to +623
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check for mNumFusedSharedExperts <= WarpSize is currently placed inside the if (data.mNumExpertGroups > 1) block. However, routingMainKernel always assumes that shared experts can be handled by a single warp (using laneIdx), regardless of whether expert groups are used. This check should be moved outside the conditional block to ensure it is always enforced.

Comment on lines +620 to +623
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟑 Minor

fusedSharedExperts <= WarpSize check should be unconditional.

This validation is guarded by if (data.mNumExpertGroups > 1) (line 605), but the fused shared expert writes at lines 261-265 and 272-274 use laneIdx < mNumFusedSharedExperts regardless of expert groups. If mNumExpertGroups <= 1 and mNumFusedSharedExperts > WarpSize, the writes would silently skip some fused experts.

Suggested fix

Move the check out of the if (data.mNumExpertGroups > 1) block:

+  FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
+                   "Number of fused shared experts (%d) must be less than warp size.",
+                   data.mNumFusedSharedExperts);
+
   if (data.mNumExpertGroups > 1) {
     FLASHINFER_CHECK(data.mNumExpertGroups <= MaxNumGroups,
                      ...);
     ...
-
-    FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize,
-                     "Number of fused shared experts (%d) must be less than warp size.",
-                     data.mNumFusedSharedExperts);
   }
πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/trtllm_fused_moe_routing_deepseek.cu` around lines 616 - 619, The check
ensuring data.mNumFusedSharedExperts <= WarpSize must be unconditional because
fused shared-expert writes use laneIdx < mNumFusedSharedExperts regardless of
expert group count; move the FLASHINFER_CHECK(data.mNumFusedSharedExperts <=
WarpSize, ...) out of the if (data.mNumExpertGroups > 1) block so it always
runs, ensuring data.mNumFusedSharedExperts is validated before any code paths
that use mNumFusedSharedExperts-induced lane comparisons
or writes (references: data.mNumFusedSharedExperts, WarpSize,
mNumExpertGroups).

}
FLASHINFER_CHECK(data.mNumExperts % 4 == 0,
"Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts);

int const numBlocks = data.mNumTokens;
int const numThreadsHist = getMaxNumExperts(data.mNumExperts);

bool const useSingleCluster = data.mNumTokens <= 1024;
int numThreadsPerCluster = numThreadsHist * NumBlocksPerCluster;
bool const useSingleCluster =
data.mNumTokens <= 1024 && data.mNumTokens * topK <= numThreadsPerCluster;
if (!useSingleCluster) {
// Reset the global histograms (not used in single-cluster code path).
// Cover both for the cooperative and two-kernel code paths.
Expand All @@ -629,7 +653,7 @@ void runImpl(Data& data, void* stream) {
int const numBlocksCoop = 128;

// Maximum number of tokens supported by the kernel using a cooperative launch.
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK;
int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / topK;
if (data.mPtrTopKIds == nullptr) {
int const numThreadsMain =
max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts));
Expand All @@ -645,6 +669,12 @@ void runImpl(Data& data, void* stream) {
stream, data.mNumExpertGroups > 1);
}

if (data.mNumFusedSharedExperts > 0) {
data.mNumExperts += data.mNumFusedSharedExperts;
data.mTopK += data.mNumFusedSharedExperts;
data.mNumLocalExperts += data.mNumFusedSharedExperts;
}
Comment on lines +672 to +676
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Updating data.mNumExperts and data.mTopK after the first kernel launch (line 656 or 662) leads to several issues:

  1. numThreadsMain (line 655) and the histogram initialization inside routingMainKernel (line 85) use the original routed expert count, meaning the histogram entries for shared experts are never initialized to zero. This can cause garbage values to be used as offsets in subsequent permutation kernels.
  2. The dispatching macro LAUNCH_ROUTING_DEEPSEEK uses data.mNumExperts to select the MaxNumExperts template parameter. If the total expert count (routed + shared) crosses a threshold (e.g., 256 to 257), the first and second launches will use different template instantiations, which is inconsistent.

You should calculate the total expert count and top-k at the beginning of runImpl and ensure that initialization kernels use the total count, while routingMainKernel receives the routed count for its indexing logic.


if (data.mPtrPermutedIdxSize != nullptr) {
if (useSingleCluster) {
LAUNCH_ROUTING_DEEPSEEK(data,
Expand All @@ -659,7 +689,7 @@ void runImpl(Data& data, void* stream) {
/*smemSize=*/0, // No dynamic smem
stream, data.mNumExpertGroups > 1);
} else {
const int32_t expandedIdxSize = data.mNumTokens * data.mTopK;
const int32_t expandedIdxSize = data.mNumTokens * topK;
const int32_t histogramEltsPerBlock = 8 * numThreadsHist;
const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist;

Expand Down
Loading
Loading