diff --git a/csrc/moe_utils_binding.cu b/csrc/moe_utils_binding.cu index 8cfd00f3eb..c46bce72d0 100644 --- a/csrc/moe_utils_binding.cu +++ b/csrc/moe_utils_binding.cu @@ -329,6 +329,14 @@ void moe_sort( routingData.mLocalExpertsStrideLog2 = 0; routingData.mNumLocalExperts = num_local_experts; + // Fused shared expert fields — unused in cute DSL moe_sort path, but must be zero-initialized + // because the routing kernel reads mNumFusedSharedExperts unconditionally (adds it to numExperts + // and topK at lines 576-577 of trtllm_fused_moe_routing_deepseek.cu). + routingData.mNumFusedSharedExperts = 0; + routingData.mSharedExpertTokenOffset = 0; + routingData.mSharedExpertNumTokens = 0; + routingData.mTotalExpertsPerToken = top_k; + // DeepSeekV3 specific parameters // For moe_sort, we use n_group=1, topk_group=1 since experts are already selected routingData.mNumExpertGroups = 1; diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 64fece5021..33d3566ff1 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -256,33 +256,36 @@ class FusedMoeLauncher { Tensor num_non_exiting_ctas; void prepare_routing_common() { + int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts; + int32_t const totalNumExperts = args->num_experts + args->num_fused_shared_experts; + // Allocate routing phase workspace tensors - num_tokens_per_expert = alloc_tensor({args->num_experts}, dl_int32, hidden_states.device()); + num_tokens_per_expert = alloc_tensor({totalNumExperts}, dl_int32, hidden_states.device()); int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( - args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim); total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device()); expanded_idx_to_permuted_idx = - alloc_tensor({args->num_tokens * args->top_k}, dl_int32, hidden_states.device()); + alloc_tensor({args->num_tokens * totalExpertsPerToken}, dl_int32, hidden_states.device()); permuted_idx_to_token_idx = alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device()); expert_indexes = - alloc_tensor({args->num_tokens, args->top_k}, dl_int32, hidden_states.device()); + alloc_tensor({args->num_tokens, totalExpertsPerToken}, dl_int32, hidden_states.device()); // expert_weights allocation should be done by derived class since data type could vary - int64_t const size_of_expert_count_histogram = std::max(args->num_experts * 2, 256 * 2); + int64_t const size_of_expert_count_histogram = std::max(totalNumExperts * 2, 256 * 2); expert_count_histogram = alloc_tensor({size_of_expert_count_histogram}, dl_int32, // 256 is the max number of threads per block // and max number of experts hidden_states.device()); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( - args->num_tokens, args->top_k, args->num_experts, tile_tokens_dim); + args->num_tokens, totalExpertsPerToken, totalNumExperts, tile_tokens_dim); cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); @@ -334,14 +337,17 @@ class FusedMoeLauncher { this->activation_type, this->use_shuffled_weight, this->weight_layout); } + int32_t const effectiveTopK = args->top_k + args->num_fused_shared_experts; + int32_t const effectiveLocalExperts = args->local_num_experts + args->num_fused_shared_experts; + if (moe_tactic == -1) { - moe_tactic = moe_runner->getDefaultValidConfigIndex( - args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, - args->num_tokens); + moe_tactic = moe_runner->getDefaultValidConfigIndex(effectiveTopK, args->hidden_size, + args->intermediate_size, + effectiveLocalExperts, args->num_tokens); } auto valid_cfgs = - moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size, - args->local_num_experts, args->num_tokens); + moe_runner->getValidConfigIndices(effectiveTopK, args->hidden_size, args->intermediate_size, + effectiveLocalExperts, args->num_tokens); auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic); FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic, " for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ", @@ -377,8 +383,8 @@ class FusedMoeLauncher { routing_runner.run( args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, - args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, - args->routed_scaling_factor, workspace.routing_expert_indexes, + args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset, + args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes, static_cast(expert_count_histogram.data_ptr()), static_cast(total_num_padded_tokens.data_ptr()), static_cast(expanded_idx_to_permuted_idx.data_ptr()), @@ -910,12 +916,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { auto const routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + int32_t const totalExpertsPerToken = args->top_k + args->num_fused_shared_experts; // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr bool has_precomputed_weights = expert_weights.ndim() == 2 && expert_weights.size(0) > 0; if (!has_precomputed_weights) { // Allocate expert_weights buffer for routing output - FusedMoeLauncher::expert_weights = - alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + FusedMoeLauncher::expert_weights = alloc_tensor({args->num_tokens, totalExpertsPerToken}, + dl_bfloat16, hidden_states.device()); workspace.expert_weights = FusedMoeLauncher::expert_weights.data_ptr(); } else { workspace.expert_weights = const_cast(expert_weights.data_ptr()); @@ -946,12 +953,13 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8."; TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8."; + int64_t const totalLocalExperts = args->local_num_experts + args->num_fused_shared_experts; if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) << "gemm1_weights_scale must be float."; TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts) - << "gemm1_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), totalLocalExperts) + << "gemm1_weights_scale has incorrect dim 0."; TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) << "intermediate_size must be a multiple of 128."; TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), @@ -971,8 +979,8 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32) << "gemm2_weights_scale must be float."; TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts) - << "gemm2_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), totalLocalExperts) + << "gemm2_weights_scale has incorrect dim 0."; TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128) << "gemm2_weights_scale has incorrect shape."; TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128) @@ -1082,8 +1090,9 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { // routing runner to use the pre-computed expert indices from workspace.routing_expert_indexes routing_runner.run( use_precomputed ? nullptr : args->routing_logits, args->routing_bias, args->num_tokens, - args->num_experts, args->top_k, args->n_group, args->topk_group, args->local_expert_offset, - args->local_num_experts, args->routed_scaling_factor, workspace.routing_expert_indexes, + args->num_experts, args->top_k, args->num_fused_shared_experts, args->n_group, + args->topk_group, args->local_expert_offset, args->local_num_experts, + args->routed_scaling_factor, workspace.routing_expert_indexes, static_cast(expert_count_histogram.data_ptr()), static_cast(total_num_padded_tokens.data_ptr()), static_cast(expanded_idx_to_permuted_idx.data_ptr()), @@ -1545,8 +1554,9 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { routing_runner.run( args->routing_logits, args->routing_bias, args->num_tokens, args->num_experts, args->top_k, - args->n_group, args->topk_group, args->local_expert_offset, args->local_num_experts, - args->routed_scaling_factor, static_cast(expert_indices.data_ptr()), + args->num_fused_shared_experts, args->n_group, args->topk_group, args->local_expert_offset, + args->local_num_experts, args->routed_scaling_factor, + static_cast(expert_indices.data_ptr()), static_cast(expert_count_histogram.data_ptr()), static_cast(total_num_padded_tokens.data_ptr()), static_cast(expanded_idx_to_permuted_idx.data_ptr()), @@ -1779,10 +1789,11 @@ Array trtllm_fp8_block_scale_moe( Optional routing_bias, TensorView hidden_states, TensorView hidden_states_scale, TensorView gemm1_weights, TensorView gemm1_weights_scale, TensorView gemm2_weights, TensorView gemm2_weights_scale, TensorView output, int64_t num_experts, int64_t top_k, - Optional n_group, Optional topk_group, int64_t intermediate_size, - int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, - bool enable_pdl, Array config_index, Fp8QuantizationType quantization_type) { + Optional num_fused_shared_experts, Optional n_group, + Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, + int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, + bool use_shuffled_weight, int64_t weight_layout, bool do_finalize, bool enable_pdl, + Array config_index, Fp8QuantizationType quantization_type) { // Basic type validation auto dtype = hidden_states.dtype(); @@ -1843,9 +1854,13 @@ Array trtllm_fp8_block_scale_moe( auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); + int64_t const nFusedShared = num_fused_shared_experts.value_or(0); + int64_t const totalExpertsPerToken = top_k + nFusedShared; + int64_t const totalLocalExperts = local_num_experts + nFusedShared; + auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type); - std::set selected_tile_nums = - computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts); + std::set selected_tile_nums = computeSelectedTileN( + supported_tile_nums, num_tokens, totalExpertsPerToken, totalLocalExperts); // Create a map of launchers for each tile size std::unordered_map> launchers_map; @@ -1855,6 +1870,7 @@ Array trtllm_fp8_block_scale_moe( auto args = std::make_unique(); args->num_tokens = num_tokens; args->num_experts = num_experts; + args->num_fused_shared_experts = nFusedShared; args->hidden_size = hidden_size; args->hidden_size_output = args->hidden_size; args->top_k = top_k; diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 5408d2d059..16fd56f27e 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -250,17 +250,28 @@ __global__ void routingMainKernel(KernelParams params) { auto finalScore = OutputT{scoreNorm * params.mRouteScale / redNorm}; // write expert idx out already - auto idxTopK = blockIdx.x * params.mTopK + laneIdx; + auto idxTopK = blockIdx.x * params.mTotalExpertsPerToken + laneIdx; + auto idxShared = blockIdx.x * params.mTotalExpertsPerToken + params.mTopK + laneIdx; if (laneIdx < params.mTopK && params.mPtrTopKPacked != nullptr) { PackedScoreIdx packedScore{static_cast(finalScore), static_cast(expertIdx)}; params.mPtrTopKPacked[idxTopK] = packedScore; } + if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKPacked != nullptr) { + PackedScoreIdx packedScore{static_cast(1.0F), + static_cast(params.mNumExperts + laneIdx)}; + params.mPtrTopKPacked[idxShared] = packedScore; + } + if (laneIdx < params.mTopK && params.mPtrTopKWeights != nullptr && params.mPtrTopKIds == nullptr) { params.mPtrTopKWeights[idxTopK] = finalScore; } + + if (laneIdx < params.mNumFusedSharedExperts && params.mPtrTopKWeights != nullptr) { + params.mPtrTopKWeights[idxShared] = static_cast(1.0F); + } } } } @@ -561,6 +572,11 @@ void runImpl(Data& data, void* stream) { FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); + + int const numExperts = data.mNumExperts + data.mNumFusedSharedExperts; + int const topK = data.mTopK + data.mNumFusedSharedExperts; + int const numThreadsHist = getMaxNumExperts(numExperts); + // Test limits according to values passed in launch, see definition of LAUNCH_ROUTING_DEEPSEEK if (data.mNumExperts <= NumKimiK2Experts) { FLASHINFER_CHECK( @@ -573,6 +589,9 @@ void runImpl(Data& data, void* stream) { "When NumExperts > NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts, data.mTopK); } + FLASHINFER_CHECK(topK <= MaxSupportedTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts, + topK); FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK); FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, @@ -598,14 +617,19 @@ void runImpl(Data& data, void* stream) { data.mNumExperts / data.mNumExpertGroups <= WarpSize, "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); + + FLASHINFER_CHECK(data.mNumFusedSharedExperts <= WarpSize, + "Number of fused shared experts (%d) must be less than warp size.", + data.mNumFusedSharedExperts); } FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); int const numBlocks = data.mNumTokens; - int const numThreadsHist = getMaxNumExperts(data.mNumExperts); - bool const useSingleCluster = data.mNumTokens <= 1024; + int numThreadsPerCluster = numThreadsHist * NumBlocksPerCluster; + bool const useSingleCluster = + data.mNumTokens <= 1024 && data.mNumTokens * topK <= numThreadsPerCluster; if (!useSingleCluster) { // Reset the global histograms (not used in single-cluster code path). // Cover both for the cooperative and two-kernel code paths. @@ -629,7 +653,7 @@ void runImpl(Data& data, void* stream) { int const numBlocksCoop = 128; // Maximum number of tokens supported by the kernel using a cooperative launch. - int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; + int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / topK; if (data.mPtrTopKIds == nullptr) { int const numThreadsMain = max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); @@ -645,6 +669,12 @@ void runImpl(Data& data, void* stream) { stream, data.mNumExpertGroups > 1); } + if (data.mNumFusedSharedExperts > 0) { + data.mNumExperts += data.mNumFusedSharedExperts; + data.mTopK += data.mNumFusedSharedExperts; + data.mNumLocalExperts += data.mNumFusedSharedExperts; + } + if (data.mPtrPermutedIdxSize != nullptr) { if (useSingleCluster) { LAUNCH_ROUTING_DEEPSEEK(data, @@ -659,7 +689,7 @@ void runImpl(Data& data, void* stream) { /*smemSize=*/0, // No dynamic smem stream, data.mNumExpertGroups > 1); } else { - const int32_t expandedIdxSize = data.mNumTokens * data.mTopK; + const int32_t expandedIdxSize = data.mNumTokens * topK; const int32_t histogramEltsPerBlock = 8 * numThreadsHist; const int32_t offsetEltsPerBlock = NumEltsPerOffsetTilePerThread * numThreadsHist; diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index af48040d0a..ce987cd40f 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -50,14 +50,14 @@ Runner::Runner() {} Runner::Runner(int32_t tileTokensDim) : mTileTokensDim(tileTokensDim) {} void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts, - int32_t topK, int32_t nGroup, int32_t topkGroup, int32_t localExpertOffset, - int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes, - int32_t* expertCountHistogram, int32_t* permutedIdxSize, - int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, - int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert, - int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, - int32_t* numNonExitingCtas, btg::Dtype dtypeElt, btg::Dtype dtypeBias, - bool useRoutingScalesOnInput, bool useDeepSeekFp8, + int32_t topK, int32_t numFusedSharedExpert, int32_t nGroup, int32_t topkGroup, + int32_t localExpertOffset, int32_t localNumExperts, float routedScalingFactor, + int32_t* routingExpertIndexes, int32_t* expertCountHistogram, + int32_t* permutedIdxSize, int32_t* expandedIdxToPermutedIdx, + int32_t* permutedIdxToExpandedIdx, int32_t* permutedIdxToTokenIdx, + void* expertWeights, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, + int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, btg::Dtype dtypeElt, + btg::Dtype dtypeBias, bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) { if (routingMethodType == RoutingMethodType::DeepSeekV3) { FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); @@ -70,6 +70,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mDtypeScore = btg::Dtype::Fp32; // for DeepSeek, the score is currently always fp32 routingData.mUsePdl = true; + int32_t const totalExpertsPerToken = topK + numFusedSharedExpert; + // output: routingData.mPtrTopKPacked = routingExpertIndexes; routingData.mPtrExpertCounts = expertCountHistogram; @@ -88,9 +90,11 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mPtrScores = reinterpret_cast(routingLogits); routingData.mNumTokens = numTokens; routingData.mNumExperts = numExperts; + routingData.mNumFusedSharedExperts = numFusedSharedExpert; routingData.mNumExpertGroups = nGroup; routingData.mNumLimitedGroups = topkGroup; routingData.mTopK = topK; + routingData.mTotalExpertsPerToken = totalExpertsPerToken; routingData.mPaddingLog2 = computeLog2(mTileTokensDim); routingData.mTileTokensDim = mTileTokensDim; routingData.mLocalExpertsStartIdx = localExpertOffset; @@ -98,8 +102,23 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 routingData.mNumLocalExperts = localNumExperts; routingData.mRouteScale = routedScalingFactor; routingData.mUseRoutingSoftmax = false; + + int32_t const numDevices = (localNumExperts > 0) ? numExperts / localNumExperts : 1; + int32_t const deviceIndex = (localNumExperts > 0) ? localExpertOffset / localNumExperts : 0; + int32_t const baseTokensPerDevice = numTokens / numDevices; + int32_t const remainingTokens = numTokens % numDevices; + + if (deviceIndex < remainingTokens) { + routingData.mSharedExpertTokenOffset = (baseTokensPerDevice + 1) * deviceIndex; + routingData.mSharedExpertNumTokens = baseTokensPerDevice + 1; + } else { + routingData.mSharedExpertTokenOffset = remainingTokens + deviceIndex * baseTokensPerDevice; + routingData.mSharedExpertNumTokens = baseTokensPerDevice; + } moe::dev::routing::routingDeepSeek::run(routingData, stream); } else if (routingMethodType == RoutingMethodType::Llama4) { + FLASHINFER_CHECK(numFusedSharedExpert == 0, + "Llama routing method does not support fusing shared expert"); FLASHINFER_CHECK(topK == 1, "For Llama routing method, must have topK == 1"); if (nGroup > 0 || topkGroup > 0) { FLASHINFER_WARN("For Llama routing method, nGroup/topkGroup is ignored, got ", nGroup, "/", @@ -136,6 +155,8 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 } else if (routingMethodType == RoutingMethodType::Renormalize /* default */ || routingMethodType == RoutingMethodType::RenormalizeNaive /* Softmax -> TopK */ || routingMethodType == RoutingMethodType::TopK /* TopK only (no softmax) */) { + FLASHINFER_CHECK(numFusedSharedExpert == 0, + "Renormalize routing method does not support fusing shared expert"); moe::dev::routing::routingRenormalize::Data routingData; // @@ -475,6 +496,9 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace convertSfData.sfLayoutDst = btg::SfLayout::Linear; convertSfData.mUsePdl = true; + int32_t const totalNumExperts = args.num_experts + args.num_fused_shared_experts; + int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts; + // Setup activation data activationData.mDtypeElt = args.mDtypeElt; activationData.mUsePdl = true; @@ -485,7 +509,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace activationData.outDqSfsPtr = workspace.activation_output_scale; activationData.innerDim = args.intermediate_size * (isGatedActivation(args.activation_type) ? 2 : 1); - activationData.topK = args.top_k; + activationData.topK = totalExpertsPerToken; activationData.numTokens = args.num_tokens; activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx; @@ -509,8 +533,8 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace } finalizeData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx; finalizeData.numTokens = args.num_tokens; - finalizeData.numExperts = args.num_experts; - finalizeData.topK = args.top_k; + finalizeData.numExperts = totalNumExperts; + finalizeData.topK = totalExpertsPerToken; // We want to fuse unpadding into the finalize kernel, so we need to use the output hidden size. finalizeData.hiddenDim = args.hidden_size_output.value_or(args.hidden_size); finalizeData.hiddenDimPadded = args.hidden_size; @@ -523,14 +547,17 @@ std::tuple Runner::getWorkspaceSizeInBytes(MoERunnerArgs const FLASHINFER_CHECK(configIndex >= 0 && configIndex < static_cast(mPassingConfigs.size()), "Invalid MoE config index ", configIndex, ", valid range is [0, ", static_cast(mPassingConfigs.size()) - 1, "]."); + int32_t const totalLocalExperts = args.local_num_experts + args.num_fused_shared_experts; + int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts; + auto const& config = mPassingConfigs[configIndex]; auto workspace_size_fc1 = static_cast(mPermuteGemm1.getWorkspaceSizeInBytes( - args.top_k, args.hidden_size, args.intermediate_size, args.local_num_experts, args.num_tokens, - config.gemm1Config)); + totalExpertsPerToken, args.hidden_size, args.intermediate_size, totalLocalExperts, + args.num_tokens, config.gemm1Config)); auto workspace_size_fc2 = static_cast( - mGemm2.getWorkspaceSizeInBytes(args.top_k, args.hidden_size, args.intermediate_size, - args.local_num_experts, args.num_tokens, config.gemm2Config)); + mGemm2.getWorkspaceSizeInBytes(totalExpertsPerToken, args.hidden_size, args.intermediate_size, + totalLocalExperts, args.num_tokens, config.gemm2Config)); return std::make_tuple(workspace_size_fc1, workspace_size_fc2); } @@ -587,12 +614,15 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d auto const& config = mPassingConfigs[configIndex]; + int32_t const totalLocalExperts = args.local_num_experts + args.num_fused_shared_experts; + int32_t const totalExpertsPerToken = args.top_k + args.num_fused_shared_experts; + mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar, args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output, - workspace.gemm1_output_scale, args.top_k, args.hidden_size, - args.intermediate_size, args.local_num_experts, args.num_tokens, + workspace.gemm1_output_scale, totalExpertsPerToken, args.hidden_size, + args.intermediate_size, totalLocalExperts, args.num_tokens, workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace, @@ -612,11 +642,11 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d // Run gemm2 mGemm2.run(gemm2_input, gemm2_input_scale, args.gemm2_weights, args.gemm2_weights_scale, args.output2_scales_scalar, args.gemm2_bias, workspace.gemm2_output, - workspace.gemm2_output_scale, args.top_k, args.hidden_size, args.intermediate_size, - args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas, - workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, - workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream, - config.gemm2Config, enable_pdl); + workspace.gemm2_output_scale, totalExpertsPerToken, args.hidden_size, + args.intermediate_size, totalLocalExperts, args.num_tokens, + workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens, + workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, + workspace.bmm2_workspace, device, stream, config.gemm2Config, enable_pdl); // Run finalize if (args.do_finalize) { diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 7e0760e7b2..894cb585df 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1177,6 +1177,7 @@ def forward( output, kwargs["num_experts"], self.top_k, + kwargs.get("num_fused_shared_experts", 0), kwargs["n_group"], kwargs["topk_group"], self.intermediate_size, @@ -1657,6 +1658,7 @@ def trtllm_fp8_block_scale_moe_op( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, + num_fused_shared_experts: int = 0, ) -> List[torch.Tensor]: # Determine routing mode: compute from logits or use pre-computed if routing_logits is None: @@ -1752,7 +1754,9 @@ def trtllm_fp8_block_scale_moe_op( weight_layout=weight_layout, do_finalize=do_finalize, enable_pdl=enable_pdl, + num_fused_shared_experts=num_fused_shared_experts, ) + _nfse = num_fused_shared_experts if num_fused_shared_experts is not None else 0 # Call the C++ function for block scale MoE intermediate_output = moe_op.trtllm_fp8_block_scale_moe( routing_logits, @@ -1768,6 +1772,7 @@ def trtllm_fp8_block_scale_moe_op( output, num_experts, top_k, + _nfse, n_group, topk_group, intermediate_size, @@ -2543,6 +2548,7 @@ def trtllm_fp8_block_scale_moe( enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, + num_fused_shared_experts: Optional[int] = None, ) -> Union[List[torch.Tensor], torch.Tensor]: """FP8 block scale MoE operation. @@ -2559,7 +2565,7 @@ def trtllm_fp8_block_scale_moe( - [num_experts, hidden_size, intermediate_size] if weight_layout == WeightLayout.MajorK - [num_experts, hidden_size//128, intermediate_size, 128] if weight_layout == WeightLayout.BlockMajorK gemm2_weights_scale: [num_experts, hidden_size//(32 if mxfp8 else 128), intermediate_size//(32 if mxfp8 else 128)] tensor of second layer block scales - num_experts: Total number of experts + num_experts: Total number of routed experts top_k: Number of experts to route to per token n_group: Number of expert groups topk_group: Number of groups to consider for top-k routing @@ -2575,6 +2581,8 @@ def trtllm_fp8_block_scale_moe( enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) fp8_quantization_type: Type of FP8 quantization to use (default: DeepSeekFp8) + num_fused_shared_experts: Number of shared experts to fuse into the MoE kernel (default: None/0). + When > 0, weight tensors must have num_experts + num_fused_shared_experts in the expert dim. Returns: when do_finalize=True, returns the final MoE output. otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx) that need further processing. @@ -2609,6 +2617,7 @@ def trtllm_fp8_block_scale_moe( enable_pdl, tune_max_num_tokens, fp8_quantization_type, + num_fused_shared_experts if num_fused_shared_experts is not None else 0, ) if do_finalize: diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index ba90742ce0..6e0b39b4c9 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -102,6 +102,12 @@ struct DataBase { int32_t mLocalExpertsStartIdx; int32_t mLocalExpertsStrideLog2; int32_t mNumLocalExperts; + + /// For fused shared expert + int32_t mNumFusedSharedExperts; + int32_t mSharedExpertTokenOffset; + int32_t mSharedExpertNumTokens; + int32_t mTotalExpertsPerToken; }; template @@ -135,6 +141,11 @@ struct KernelParamsBase { int32_t mLocalExpertsStrideLog2 = 0; int32_t mNumLocalExperts = 0; + int32_t mNumFusedSharedExperts = 0; + int32_t mSharedExpertTokenOffset = 0; + int32_t mSharedExpertNumTokens = 0; + int32_t mTotalExpertsPerToken = 0; + // Public initialization function - make it a template to accept different Data types template void setBaseParams(DataType const& data) { @@ -158,6 +169,11 @@ struct KernelParamsBase { mLocalExpertsStartIdx = data.mLocalExpertsStartIdx; mLocalExpertsStrideLog2 = data.mLocalExpertsStrideLog2; mNumLocalExperts = data.mNumLocalExperts; + + mNumFusedSharedExperts = data.mNumFusedSharedExperts; + mSharedExpertTokenOffset = data.mSharedExpertTokenOffset; + mSharedExpertNumTokens = data.mSharedExpertNumTokens; + mTotalExpertsPerToken = data.mTotalExpertsPerToken; } }; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 46617e5dbd..5024d6ee5b 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -120,9 +120,9 @@ class Runner { explicit Runner(int32_t tileTokensDim); void run(void* routingLogits, void* routingBias, int32_t numTokens, int32_t numExperts, - int32_t topK, int32_t nGroups, int32_t topkGroups, int32_t localExpertOffset, - int32_t localNumExperts, float routedScalingFactor, int32_t* routingExpertIndexes, - int32_t* expertCountHistogram, int32_t* permutedIdxSize, + int32_t topK, int32_t numFusedSharedExpert, int32_t nGroups, int32_t topkGroups, + int32_t localExpertOffset, int32_t localNumExperts, float routedScalingFactor, + int32_t* routingExpertIndexes, int32_t* expertCountHistogram, int32_t* permutedIdxSize, int32_t* expandedIdxToPermutedIdx, int32_t* permutedIdxToExpandedIdx, int32_t* permutedIdxToTokenIdx, void* expertWeights, int32_t* numTokensPerExpert, int32_t* ctaIdxXyToBatchIdx, int32_t* ctaIdxXyToMnLimit, int32_t* numNonExitingCtas, @@ -287,6 +287,7 @@ struct MoERunnerArgs { int32_t num_tokens{0}; int32_t num_experts{0}; + int32_t num_fused_shared_experts{0}; // Hidden dimension input of MoE block. It might be padded. int32_t hidden_size{0}; // Hidden dimension output of MoE block. It is not padded. diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 62f9860644..af68d7880f 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -1113,6 +1113,9 @@ def call_moe( enable_pdl = kwargs.get("enable_pdl") hidden_states_scale = kwargs["hidden_states_scale"] hidden_states_quant = kwargs["hidden_states_quant"] + num_fused_shared_experts = kwargs.get("num_fused_shared_experts", 0) + + num_routed_experts = num_experts - num_fused_shared_experts # Generate block scales and quantize hidden states at runtime hidden_states_fp8 = hidden_states_quant.to(torch.float8_e4m3fn) @@ -1140,13 +1143,13 @@ def call_moe( static_data["gemm1_scales"], static_data["gemm2_weights"], static_data["gemm2_scales"], - num_experts, - top_k, + num_routed_experts, + top_k - num_fused_shared_experts, n_groups, top_k_groups, intermediate_size, 0, - num_experts, + num_routed_experts, routed_scaling, routing_method_type, use_shuffled_weight=static_data["use_shuffled_weight"], @@ -1154,6 +1157,9 @@ def call_moe( enable_pdl=enable_pdl, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, fp8_quantization_type=quantization_mode, + num_fused_shared_experts=num_fused_shared_experts + if num_fused_shared_experts > 0 + else None, ) return output.to(torch.float) @@ -1631,28 +1637,51 @@ def __init__( self.gemm2_bias = gemm2_bias -def routing_reference(expertLogits, topK, padding): +def routing_reference(expertLogits, topK, padding, num_fused_shared_experts=0): """Reference routing implementation for permutation calculation.""" originalDevice = expertLogits.device expertLogits = expertLogits.cpu() numTokens, numExperts = expertLogits.shape assert topK <= numExperts - numTokensPerExpert = torch.zeros(numExperts, dtype=torch.int64) - expandedTokenIdxToExpert = -torch.ones(numTokens * topK, dtype=torch.int64) - expandedTokenIdxToIdxInExpert = -torch.ones(numTokens * topK, dtype=torch.int64) + numTotalExperts = numExperts + num_fused_shared_experts + totalExpertsPerToken = topK + num_fused_shared_experts + + numTokensPerExpert = torch.zeros(numTotalExperts, dtype=torch.int64) + expandedTokenIdxToExpert = -torch.ones( + numTokens * totalExpertsPerToken, dtype=torch.int64 + ) + expandedTokenIdxToIdxInExpert = -torch.ones( + numTokens * totalExpertsPerToken, dtype=torch.int64 + ) topKLogits, topKIndices = torch.topk(expertLogits, topK, dim=1) + if num_fused_shared_experts > 0: + sharedLogits = torch.ones( + numTokens, num_fused_shared_experts, dtype=topKLogits.dtype + ) + topKLogits = torch.cat((topKLogits, sharedLogits), dim=1) + sharedIndices = ( + torch.arange( + numExperts, + numExperts + num_fused_shared_experts, + dtype=topKIndices.dtype, + ) + .unsqueeze(0) + .expand(numTokens, -1) + ) + topKIndices = torch.cat((topKIndices, sharedIndices), dim=1) + for tokenIdx in range(numTokens): - for k in range(topK): - expandedIdx = tokenIdx * topK + k + for k in range(totalExpertsPerToken): + expandedIdx = tokenIdx * totalExpertsPerToken + k expertIndex = topKIndices[tokenIdx, k] expandedTokenIdxToExpert[expandedIdx] = expertIndex expandedTokenIdxToIdxInExpert[expandedIdx] = numTokensPerExpert[expertIndex] numTokensPerExpert[expertIndex] += 1 - paddedTokensPerExpertPrefixSum = torch.zeros(numExperts + 1, dtype=torch.int64) - for ii in range(numExperts): + paddedTokensPerExpertPrefixSum = torch.zeros(numTotalExperts + 1, dtype=torch.int64) + for ii in range(numTotalExperts): def divUpMul(a, b): return (a + b - 1) // b * b @@ -1660,14 +1689,16 @@ def divUpMul(a, b): paddedTokensPerExpertPrefixSum[ii + 1] = paddedTokensPerExpertPrefixSum[ ii ] + divUpMul(numTokensPerExpert[ii], padding) - permutedBufferSize = paddedTokensPerExpertPrefixSum[numExperts] + permutedBufferSize = paddedTokensPerExpertPrefixSum[numTotalExperts] - expandedTokenIdxToPermutedIdx = -torch.ones(numTokens * topK, dtype=torch.int64) + expandedTokenIdxToPermutedIdx = -torch.ones( + numTokens * totalExpertsPerToken, dtype=torch.int64 + ) permutedIdxToExpandedIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) permutedIdxToTokenIdx = -torch.ones(permutedBufferSize, dtype=torch.int64) for tokenIdx in range(numTokens): - for k in range(topK): - expandedIdx = tokenIdx * topK + k + for k in range(totalExpertsPerToken): + expandedIdx = tokenIdx * totalExpertsPerToken + k expert = expandedTokenIdxToExpert[expandedIdx] offsetWithinExpert = expandedTokenIdxToIdxInExpert[expandedIdx] offsetForExpert = paddedTokensPerExpertPrefixSum[expert] @@ -1743,17 +1774,17 @@ def routing_reference_no_aux( routed_scaling, padding, use_routing_scales_on_input=False, + num_fused_shared_experts=0, ): """Tiered TopK routing used by DeepSeek.""" routing_logits = expert_logits.to(dtype=torch.float, device="cuda") if use_routing_scales_on_input: - # if using routing scales on input, topK == 1 and the score is a plain sigmoid scores = F.sigmoid(routing_logits) else: scores = noaux_tc_ref( routing_logits, routing_bias, n_groups, top_k_groups, top_k, routed_scaling ) - permute_info = routing_reference(scores, top_k, padding) + permute_info = routing_reference(scores, top_k, padding, num_fused_shared_experts) return permute_info, scores @@ -2531,6 +2562,8 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): kwargs["weight_processing"], ) + num_fused_shared_experts = kwargs.get("num_fused_shared_experts", 0) + # 2. Call MoE with runtime input quantization + kernel execution kernel_kwargs = { "expert_logits": kwargs["expert_logits"], @@ -2551,6 +2584,7 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "enable_autotune": kwargs.get("enable_autotune", True), "gemm1_bias": args.gemm1_bias, "gemm2_bias": args.gemm2_bias, + "num_fused_shared_experts": num_fused_shared_experts, } return moe_impl.call_moe( @@ -2608,6 +2642,8 @@ def run_moe_test( routed_scaling = routing_config["routed_scaling"] num_experts = routing_config["num_experts"] routing_method_type = routing_config["routing_method_type"] + num_fused_shared_experts = routing_config.get("num_fused_shared_experts", 0) + total_experts = num_experts + num_fused_shared_experts # Validation checks assert top_k <= num_experts @@ -2640,7 +2676,7 @@ def run_moe_test( ) gemm1_weights = torch.randn( ( - num_experts, + total_experts, (2 if is_gated_activation(activation_type) else 1) * intermediate_size, hidden_size, ), @@ -2648,7 +2684,7 @@ def run_moe_test( dtype=torch.bfloat16, ) gemm2_weights = torch.randn( - (num_experts, hidden_size, intermediate_size), + (total_experts, hidden_size, intermediate_size), device="cuda", dtype=torch.bfloat16, ) @@ -2666,6 +2702,7 @@ def run_moe_test( routed_scaling, padding, use_routing_scales_on_input, + num_fused_shared_experts=num_fused_shared_experts, ) elif routing_method_type == RoutingMethodType.Renormalize: permute_info, scores = routing_reference_renormalize( @@ -2711,10 +2748,10 @@ def run_moe_test( # Create arguments for reference computation args = moe_args( num_tokens, - num_experts, + total_experts, hidden_size, intermediate_size, - top_k, + top_k + num_fused_shared_experts, padding, quant_data["hidden_states"], quant_data["hidden_states_scale"], @@ -2758,6 +2795,7 @@ def run_moe_test( enable_pdl=True, hidden_states_quant=inputs_data["hidden_states"], enable_autotune=enable_autotune, + num_fused_shared_experts=num_fused_shared_experts, ) # Compare outputs @@ -3025,6 +3063,42 @@ def test_renormalize_routing( }, id="DSv3", ), + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "num_fused_shared_experts": 1, + "compatible_moe_impls": [FP8BlockScaleMoe], + "compatible_intermediate_size": [512], + "compatible_activation_types": [ActivationType.Swiglu], + "enable_autotune": False, + }, + id="DSv3_fused_shared_1", + ), + pytest.param( + { + "num_experts": 256, + "top_k": 8, + "padding": 8, + "n_groups": 8, + "top_k_groups": 4, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "num_fused_shared_experts": 2, + "compatible_moe_impls": [FP8BlockScaleMoe], + "compatible_intermediate_size": [512], + "compatible_activation_types": [ActivationType.Swiglu], + "enable_autotune": False, + }, + id="DSv3_fused_shared_2", + ), pytest.param( { "num_experts": 72,