From 46205b48917bca2da96234a6368e2fc0ad37da28 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:46:43 +0000 Subject: [PATCH 01/25] Using ActivationType instead of GatedActType, added compiled kernels, adjusted test Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_batched_gemm_runner.cu | 5 + csrc/trtllm_fused_moe_kernel_launcher.cu | 68 +++++----- csrc/trtllm_fused_moe_runner.cu | 87 +++++++++--- flashinfer/__init__.py | 2 +- flashinfer/fused_moe/__init__.py | 4 +- flashinfer/fused_moe/core.py | 32 +++-- .../trtllm/batched_gemm/KernelRunner.h | 16 +++ include/flashinfer/trtllm/fused_moe/runner.h | 55 +++++--- tests/moe/test_trtllm_gen_fused_moe.py | 127 ++++++++++-------- tests/moe/utils.py | 12 +- 10 files changed, 263 insertions(+), 145 deletions(-) diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f99e766e86..455830a4d3 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -109,6 +109,9 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( continue; } } + if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) { + continue; + } if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) { mPassingConfigIndices.push_back(i); @@ -219,6 +222,8 @@ void TrtllmGenBatchedGemmRunner::run( gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB; gemmData.mInputBuffers.mPtrScaleC = scaleC; gemmData.mInputBuffers.mPtrScaleGate = scaleGateC; + // TODO amitz-nv: Do we want to pass scaleAct instead of using scaleGateC? + gemmData.mInputBuffers.mPtrScaleAct = scaleGateC; gemmData.mInputBuffers.mPtrPerTokenSfA = mOptions.transposeMmaOutput ? 
perTokensSfB : perTokensSfA; gemmData.mInputBuffers.mPtrPerTokenSfB = diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 97a980c0d6..618c17c0da 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -36,7 +36,7 @@ namespace flashinfer { namespace btg = batchedGemm::trtllm::gen; -using tensorrt_llm::kernels::trtllmgen_moe::MoE::GatedActType; +using tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType; using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; @@ -109,7 +109,7 @@ class FusedMoeLauncher { btg::Dtype mDtypeWeights{btg::Dtype::Bfloat16}; btg::Dtype mRoutingBiasDtype{ btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias - GatedActType gated_act_type{GatedActType::SwiGlu}; + ActivationType activation_type{ActivationType::Swiglu}; public: // Constructor that initializes all TensorView members @@ -134,14 +134,14 @@ class FusedMoeLauncher { weight_layout{batchedGemm::gemm::MatrixLayout::MajorK}, mDtypeAct{btg::Dtype::Bfloat16}, mDtypeWeights{btg::Dtype::Bfloat16}, - gated_act_type{GatedActType::SwiGlu} {} + activation_type{ActivationType::Swiglu} {} protected: // Initialize common data necessary for later. // May throw exception from TVM_FFI_ICHECK. 
void init_common(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, int64_t gated_act_type); + int64_t weight_layout, ActivationType activation_type); // Routing logits [num_tokens, num_experts] void check_routing_logits_shape() const { @@ -307,7 +307,7 @@ class FusedMoeLauncher { } else { moe_runner = std::make_unique(this->mDtypeAct, this->mDtypeWeights, args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(this->gated_act_type), + this->activation_type, this->use_shuffled_weight, this->weight_layout); } @@ -377,7 +377,7 @@ class FusedMoeLauncher { void FusedMoeLauncher::init_common( std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, int64_t gated_act_type) { + int64_t weight_layout, ActivationType activation_type) { // Check devicearchitecture: Blackwell (SM 10.x) required auto device = hidden_states.device().device_id; int major = 0, minor = 0; @@ -400,9 +400,7 @@ void FusedMoeLauncher::init_common( TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) << "the value of weight_layout is not recognized"; this->weight_layout = static_cast(weight_layout); - TVM_FFI_ICHECK(0 <= gated_act_type && gated_act_type <= 1) - << "the value of gated_act_type is not recognized"; - this->gated_act_type = static_cast(gated_act_type); + this->activation_type = activation_type; } class Bf16MoeLauncher : public FusedMoeLauncher { @@ -419,12 +417,11 @@ class Bf16MoeLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { - constexpr int64_t gated_act_type = - static_cast(GatedActType::SwiGlu); // not exposed in api for now + constexpr ActivationType activation_type = ActivationType::Swiglu; // not exposed in api for now // Do base class init and perform common checks 
FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + use_shuffled_weight, weight_layout, activation_type); } void check_routing() const override { @@ -489,7 +486,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t gated_act_type, + int64_t num_tokens, int64_t act_type, bool use_shuffled_weight, int64_t weight_layout) { Array> valid_configs; @@ -502,7 +499,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { btg::Dtype::Bfloat16, // dtype_act btg::Dtype::Bfloat16, // dtype_weights false, // useDeepSeekFp8 - tile_N, static_cast(gated_act_type), use_shuffled_weight, + tile_N, static_cast(act_type), use_shuffled_weight, static_cast(weight_layout)); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -535,9 +532,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, bool use_routing_scales_on_input_param) { - constexpr int64_t gated_act_type = - static_cast(GatedActType::SwiGlu); // not exposed in api for now + int64_t weight_layout, bool use_routing_scales_on_input_param, ActivationType activation_type) { this->use_routing_scales_on_input = use_routing_scales_on_input_param; @@ -554,7 +549,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { mDtypeWeights = btg::Dtype::E4m3; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + use_shuffled_weight, weight_layout, activation_type); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -682,7 +677,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { public: static Array> 
getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t gated_act_type, + int64_t num_tokens, int64_t act_type, bool use_shuffled_weight, int64_t weight_layout, btg::Dtype dtype_act, btg::Dtype dtype_weights) { Array> valid_configs; @@ -695,7 +690,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( dtype_act, dtype_weights, false, // useDeepSeekFp8 - tile_N, static_cast(gated_act_type), use_shuffled_weight, + tile_N, static_cast(act_type), use_shuffled_weight, static_cast(weight_layout)); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -728,7 +723,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { - constexpr int64_t gated_act_type = static_cast(GatedActType::SwiGlu); + constexpr ActivationType activation_type = ActivationType::Swiglu; mDtypeAct = btg::Dtype::E4m3; mDtypeWeights = btg::Dtype::E4m3; @@ -748,7 +743,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { args->mDtypeOut = btg::Dtype::Bfloat16; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + use_shuffled_weight, weight_layout, activation_type); } void check_routing() const override { @@ -974,7 +969,7 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { std::move(args), tile_tokens_dim, routing_method_type, /*use_shuffled_weight=*/true, static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), - static_cast(GatedActType::SwiGlu)); + ActivationType::Swiglu); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -1077,7 +1072,7 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( btg::Dtype::Bfloat16, 
btg::Dtype::MxInt4, false, // useDeepSeekFp8 - tile_N, GatedActType::SwiGlu, + tile_N, ActivationType::Swiglu, /*useShuffledMatrixA*/ true, batchedGemm::gemm::MatrixLayout::BlockMajorK); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -1132,7 +1127,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, int64_t gated_act_type, btg::Dtype dtype_act, + int64_t weight_layout, ActivationType activation_type, btg::Dtype dtype_act, btg::Dtype dtype_weights) { static const std::tuple device_props = [this] { int major, minor; @@ -1156,7 +1151,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { mDtypeWeights = dtype_weights; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + use_shuffled_weight, weight_layout, activation_type); } void check_routing() const override { @@ -1376,7 +1371,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t gated_act_type, + int64_t num_tokens, int64_t act_type, btg::Dtype dtype_act, btg::Dtype dtype_weights) { Array> valid_configs; @@ -1388,7 +1383,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( dtype_act, dtype_weights, false, // useDeepSeekFp8 - tile_N, static_cast(gated_act_type), + tile_N, static_cast(act_type), /*useShuffledMatrixA*/ true); // FP4 uses shuffled weights auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -1482,9 +1477,10 @@ Tensor trtllm_fp8_per_tensor_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional 
routed_scaling_factor, bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl, - Array config_index) { + Array config_index, int64_t activation_type) { // Basic type validation auto dtype = hidden_states.dtype(); + auto activation = static_cast(activation_type); if (use_routing_scales_on_input) { TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; } else if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { @@ -1541,7 +1537,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, - weight_layout, use_routing_scales_on_input); + weight_layout, use_routing_scales_on_input, activation); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1659,7 +1655,7 @@ Array trtllm_fp4_block_scale_moe( Optional output2_scales_scalar, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t gated_act_type, + int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t act_type, TensorView output, Array config_index) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); @@ -1764,7 +1760,7 @@ Array trtllm_fp4_block_scale_moe( gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, topk_ids, expert_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, - /*weight_layout=*/0, gated_act_type, mDtypeAct, mDtypeWeights); + /*weight_layout=*/0, static_cast(act_type), mDtypeAct, mDtypeWeights); launchers_map[curr_tile_N] = 
std::move(launcher); } @@ -1877,7 +1873,7 @@ Array trtllm_mxint4_block_scale_moe( Array> trtllm_get_valid_moe_configs( int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size, - int64_t const num_local_experts, int64_t const gated_act_type, bool const use_shuffled_weight, + int64_t const num_local_experts, int64_t const act_type, bool const use_shuffled_weight, int64_t const weight_layout, int64_t const num_tokens) { auto dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); @@ -1890,7 +1886,7 @@ Array> trtllm_get_valid_moe_configs( if (dtype_act == btg::Dtype::Bfloat16 && dtype_weights == btg::Dtype::Bfloat16) { // BF16 MoE return Bf16MoeLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens, gated_act_type, + num_local_experts, num_tokens, act_type, use_shuffled_weight, weight_layout); } else if (dtype_act == btg::Dtype::E4m3 && dtype_weights == btg::Dtype::E4m3) { @@ -1898,7 +1894,7 @@ Array> trtllm_get_valid_moe_configs( if (!useDeepSeekFp8) { // FP8 per-tensor scale return Fp8PerTensorLauncher::getValidConfigs( - top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, gated_act_type, + top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, act_type, use_shuffled_weight, weight_layout, dtype_act, dtype_weights); } else { // FP8 block scale @@ -1909,7 +1905,7 @@ Array> trtllm_get_valid_moe_configs( } else if (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) { // FP4 block scale return FP4BlockScaleLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens, gated_act_type, + num_local_experts, num_tokens, act_type, dtype_act, dtype_weights); } diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index b5ff5757c9..eef0ba2473 100644 --- a/csrc/trtllm_fused_moe_runner.cu 
+++ b/csrc/trtllm_fused_moe_runner.cu @@ -189,13 +189,46 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 namespace PermuteGemm1 { +using tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType; +using tensorrt_llm::kernels::trtllmgen_moe::MoE::isGatedActivation; +using tensorrt_llm::kernels::trtllmgen_moe::MoE::serializeActivationType; + +static inline ActType activationTypeToGatedActType(ActivationType actType) { + switch (actType) { + case ActivationType::Swiglu: + return ActType::SwiGlu; + case ActivationType::Geglu: + return ActType::GeGlu; + default: + FLASHINFER_CHECK(false, "Unsupported gated activation type ", + serializeActivationType(actType), " of enum ", static_cast(actType)); + } + return ActType::SwiGlu; +} + +static inline EltwiseActType activationTypeToEltwiseActType(ActivationType actType) { + switch (actType) { + case ActivationType::Relu2: + return EltwiseActType::Relu2; + case ActivationType::Identity: + return EltwiseActType::None; + default: + FLASHINFER_CHECK(false, "Unsupported eltwise activation type ", + serializeActivationType(actType), " of enum ", static_cast(actType)); + } + return EltwiseActType::None; +} + tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( btg::Dtype dtypeAct, btg::Dtype dtypeWeights, int32_t tileTokensDim, bool useDeepSeekFp8, - MoE::GatedActType gatedActType, bool useShuffledMatrixA, + ActivationType activationType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) { - if (gatedActType == MoE::GatedActType::SwiGlu || gatedActType == MoE::GatedActType::GeGlu) { - ActType actType = - (gatedActType == MoE::GatedActType::SwiGlu) ? 
ActType::SwiGlu : ActType::GeGlu; + int64_t actTypeInt = static_cast(activationType); + FLASHINFER_CHECK(0 <= actTypeInt && actTypeInt < static_cast(ActivationType::InvalidType), + "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt); + bool isGatedAct = isGatedActivation(activationType); + if (isGatedAct) { + ActType actType = activationTypeToGatedActType(activationType); tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { // Swap A and B dtypes because transposeMmaOutput is hardcoded to true .dtypeA = dtypeWeights, @@ -213,20 +246,36 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .weightLayout = weightLayout}; return options; } else { - FLASHINFER_CHECK(false, "Unimplemented gated act type ", - MoE::serializeGatedActType(gatedActType), " of enum ", (int)gatedActType); + EltwiseActType actType = activationTypeToEltwiseActType(activationType); + tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { + // Swap A and B dtypes because transposeMmaOutput is hardcoded to true + .dtypeA = dtypeWeights, + .dtypeB = dtypeAct, + .dtypeC = dtypeAct, + .eltwiseActType = actType, + .deepSeekFp8 = useDeepSeekFp8, + .fusedAct = false, + .routeAct = true, + .staticBatch = false, + .transposeMmaOutput = true, + .tileSize = tileTokensDim, + .epilogueTileM = 128, + .useShuffledMatrixA = useShuffledMatrixA, + .weightLayout = weightLayout}; + return options; } } Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim, - MoE::GatedActType gatedActType, bool useShuffledMatrixA, + ActivationType activationType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) : mDtypeAct(dtypeAct), mDtypeWeights(dtypeWeights), mTileTokensDim(tileTokensDim), mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( - getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, gatedActType, - useShuffledMatrixA, weightLayout))) {} + 
getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, activationType, useShuffledMatrixA, weightLayout))), mActType(activationType) {} void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* weightsScale, void* expertWeights, float* outputScalesScalar, float* outputScalesGateScalar, @@ -239,7 +288,8 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl) { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - mRunner.run(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, + int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); + mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, @@ -252,7 +302,7 @@ size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t int32_t configIndex) const { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - return mRunner.getWorkspaceSizeInBytes(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, + return mRunner.getWorkspaceSizeInBytes(numTokens, (isGatedActivation(mActType) ? 2 : 1) * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, configIndex); } @@ -261,7 +311,7 @@ int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t numTokens) const { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - return mRunner.getDefaultValidConfigIndex(numTokens, 2 * intermediateSize, hiddenSize, {}, + return mRunner.getDefaultValidConfigIndex(numTokens, 
(isGatedActivation(mActType) ? 2 : 1) * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim); } @@ -272,7 +322,7 @@ bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hidde Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); auto const isValid = - mRunner.isValidConfigIndex(configIndex, numTokens, 2 * intermediateSize, hiddenSize, {}, + mRunner.isValidConfigIndex(configIndex, numTokens, (isGatedActivation(mActType) ? 2 : 1) * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim); return isValid; @@ -292,6 +342,7 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .dtypeA = dtypeWeights, .dtypeB = dtypeAct, .dtypeC = dtypeOut, + .eltwiseActType = EltwiseActType::None, .deepSeekFp8 = useDeepSeekFp8, .fusedAct = false, .routeAct = false, @@ -373,10 +424,10 @@ std::vector Runner::getPassingConfigIndices() const { namespace MoE { Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, - int32_t tileTokensDim, GatedActType gatedActType, bool useShuffledMatrixA, + int32_t tileTokensDim, ActivationType activationType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) : mPermuteGemm1(PermuteGemm1::Runner(dtypeAct, dtypeWeights, useDeepSeekFp8, tileTokensDim, - gatedActType, useShuffledMatrixA, weightLayout)), + activationType, useShuffledMatrixA, weightLayout)), mGemm2(Gemm2::Runner(dtypeAct, dtypeWeights, btg::Dtype::Bfloat16, useDeepSeekFp8, tileTokensDim, useShuffledMatrixA, weightLayout)) { auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices(); @@ -396,7 +447,7 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8 Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) - : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, GatedActType::SwiGlu, + : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, ActivationType::Swiglu, 
useShuffledMatrixA, weightLayout) {} void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace, @@ -420,7 +471,9 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace activationData.outPtr = workspace.activation_output; activationData.inDqSfsPtr = workspace.gemm1_output_scale; activationData.outDqSfsPtr = workspace.activation_output_scale; - activationData.innerDim = args.intermediate_size * 2; + activationData.innerDim = + args.intermediate_size * + (isGatedActivation(args.activation_type) ? 2 : 1); activationData.topK = args.top_k; activationData.numTokens = args.num_tokens; activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx; diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index c22b4a0a55..c78ceb215b 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -75,8 +75,8 @@ ) from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize from .fused_moe import ( + ActivationType, RoutingMethodType, - GatedActType, cutlass_fused_moe, reorder_rows_for_gated_act_gemm, trtllm_fp4_block_scale_moe, diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index a34d37f149..f08e9c62e4 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -15,8 +15,8 @@ """ from .core import ( + ActivationType, RoutingMethodType, - GatedActType, WeightLayout, convert_to_block_layout, cutlass_fused_moe, @@ -39,8 +39,8 @@ ) __all__ = [ + "ActivationType", "RoutingMethodType", - "GatedActType", "WeightLayout", "convert_to_block_layout", "cutlass_fused_moe", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 45d5d11bb0..4c9afd48ad 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -994,7 +994,7 @@ def __init__( use_deepseek_fp8: bool, hidden_size: int, intermediate_size: int, - gated_act_type: int = GatedActType.SwiGlu, + activation_type: int = ActivationType.Swiglu, 
use_shuffled_weight: bool = False, weight_layout: int = WeightLayout.MajorK, use_packed_weights: bool = False, @@ -1007,7 +1007,7 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gated_act_type = GatedActType(gated_act_type) + self.activation_type = ActivationType(activation_type) self.use_shuffled_weight = use_shuffled_weight self.weight_layout = WeightLayout(weight_layout) self.use_packed_weights = use_packed_weights @@ -1035,7 +1035,7 @@ def get_valid_tactics( self.hidden_size, self.intermediate_size, self.num_local_experts, - self.gated_act_type, + self.activation_type, self.use_shuffled_weight, self.weight_layout, num_tokens, @@ -1237,7 +1237,7 @@ def forward( kwargs["routing_method_type"], kwargs["enable_pdl"], kwargs["do_finalize"], - self.gated_act_type, + self.activation_type, output, [-1, -1] if tactic == -1 else tactic, ) @@ -1326,7 +1326,7 @@ def trtllm_bf16_moe_op( intermediate_size=intermediate_size, weight_layout=weight_layout, use_shuffled_weight=use_shuffled_weight, - gated_act_type=GatedActType.SwiGlu, # Default for BF16 + activation_type=ActivationType.Swiglu, # Default for BF16 ) inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] @@ -1424,6 +1424,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + activation_type: ActivationType = ActivationType.Swiglu, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ -1458,6 +1459,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( intermediate_size=intermediate_size, weight_layout=WeightLayout.MajorK, use_shuffled_weight=True, + activation_type=activation_type, ) inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] @@ -1482,6 +1484,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( use_routing_scales_on_input=use_routing_scales_on_input, 
routing_method_type=routing_method_type, enable_pdl=enable_pdl, + activation_type=activation_type.value, ) # Call the C++ function result = moe_op.trtllm_fp8_per_tensor_scale_moe( @@ -1506,6 +1509,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type, enable_pdl, [-1, -1] if tactic == -1 else tactic, + activation_type.value, ) return result @@ -1722,7 +1726,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type: int, do_finalize: bool, enable_pdl: Optional[bool] = None, - gated_act_type: int = 0, + activation_type: int = 0, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -1782,7 +1786,7 @@ def trtllm_fp4_block_scale_moe_op( use_deepseek_fp8=False, hidden_size=hidden_size, intermediate_size=intermediate_size, - gated_act_type=gated_act_type, + activation_type=activation_type, weight_layout=WeightLayout.MajorK, use_shuffled_weight=True, ) @@ -1829,7 +1833,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type=routing_method_type, enable_pdl=enable_pdl, do_finalize=do_finalize, - gated_act_type=gated_act_type, + activation_type=activation_type, ) # Call the C++ function for block scale MoE @@ -1863,7 +1867,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type, do_finalize, enable_pdl, - gated_act_type, + activation_type, output, [-1, -1] if tactic == -1 else tactic, ) @@ -1908,7 +1912,7 @@ def _fake_trtllm_fp4_block_scale_moe( routing_method_type: int, do_finalize: bool, enable_pdl: bool, - gated_act_type: int, + activation_type: int, output: Optional[torch.Tensor], tune_max_num_tokens: int, ): @@ -1980,7 +1984,7 @@ def trtllm_mxint4_block_scale_moe_op( use_deepseek_fp8=False, hidden_size=hidden_size, intermediate_size=intermediate_size, - gated_act_type=GatedActType.SwiGlu, + activation_type=ActivationType.Swiglu, weight_layout=WeightLayout.BlockMajorK, use_shuffled_weight=True, ) @@ -2187,6 +2191,7 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: int = 0, enable_pdl: 
Optional[bool] = None, tune_max_num_tokens: int = 8192, + activation_type: ActivationType = ActivationType.Swiglu, ) -> torch.Tensor: """FP8 per tensor scale MoE operation. @@ -2236,6 +2241,7 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type, enable_pdl, tune_max_num_tokens, + activation_type, ) @@ -2346,7 +2352,7 @@ def trtllm_fp4_block_scale_moe( routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, - gated_act_type: int = 0, + activation_type: int = 0, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -2441,7 +2447,7 @@ def trtllm_fp4_block_scale_moe( routing_method_type, do_finalize, enable_pdl, - gated_act_type, + activation_type, output, tune_max_num_tokens, ) diff --git a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h index 970f1ae494..f73c14a5be 100644 --- a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h +++ b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h @@ -47,11 +47,27 @@ enum class ActType { GeGlu, }; +// Type of the element-wise activation to apply after the Gemm +enum class EltwiseActType { + None = 0, + // Gelu is defined as the following operation: + // act = x0 * phi(x0) + // where x0 is the output of the Gemm + // phi is the CDF of standard normal distribution approximated by + // phi(x) = 0.5 * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x))) + Gelu, + // Relu2 (also known as squared Relu) is defined as the following operation: + // act = relu(x0) ^ 2 + // where x0 is the output of the Gemm. 
+ Relu2, +}; + struct TrtllmGenBatchedGemmRunnerOptions { batchedGemm::trtllm::gen::Dtype dtypeA; batchedGemm::trtllm::gen::Dtype dtypeB; batchedGemm::trtllm::gen::Dtype dtypeC; ActType actType{ActType::SwiGlu}; + EltwiseActType eltwiseActType{EltwiseActType::None}; bool deepSeekFp8{false}; bool fusedAct{false}; bool routeAct{false}; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 3941a23249..9673df1d2d 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -136,25 +136,47 @@ class Runner { } // namespace Routing namespace MoE { -// The type of gated activation function +// The type of activation function // Please keep this in sync with the counterpart defined in flashinfer/flashinfer/fused_moe/core.py -enum class GatedActType : int64_t { - // SwiGlu - SwiGlu = 0, - // GeGlu - GeGlu = 1, +enum class ActivationType : int64_t { + Gelu = 0, + Relu = 1, + Silu = 2, + Swiglu = 3, + Geglu = 4, + SwigluBias = 5, + Relu2 = 6, + Identity = 7, + InvalidType = 8, // Must be last }; -inline std::string serializeGatedActType(GatedActType gatedActType) { - switch (gatedActType) { - case GatedActType::SwiGlu: - return "SwiGlu"; - case GatedActType::GeGlu: - return "GeGlu"; +inline std::string serializeActivationType(ActivationType activationType) { + switch (activationType) { + case ActivationType::Gelu: + return "Gelu"; + case ActivationType::Relu: + return "Relu"; + case ActivationType::Silu: + return "Silu"; + case ActivationType::Swiglu: + return "Swiglu"; + case ActivationType::Geglu: + return "Geglu"; + case ActivationType::SwigluBias: + return "SwigluBias"; + case ActivationType::Relu2: + return "Relu2"; + case ActivationType::Identity: + return "Identity"; default: - return "InvalidGatedActType"; // TODO throw error + return "InvalidType"; // TODO throw error }; } + +inline bool isGatedActivation(ActivationType activationType) { + return activationType 
== ActivationType::Swiglu || activationType == ActivationType::Geglu; +} + } // namespace MoE namespace PermuteGemm1 { @@ -162,7 +184,7 @@ class Runner { public: explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8, - int tileTokensDim, MoE::GatedActType gatedActType, bool useShuffledMatrixA, + int tileTokensDim, MoE::ActivationType activationType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weight_layout); size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -193,6 +215,7 @@ class Runner { batchedGemm::trtllm::gen::Dtype mDtypeWeights; int32_t mTileTokensDim; tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner mRunner; + tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType mActType; }; } // namespace PermuteGemm1 @@ -259,6 +282,8 @@ struct MoERunnerArgs { float* gemm1_clamp_limit = nullptr; float* gemm2_bias = nullptr; + ActivationType activation_type = ActivationType::Swiglu; + int32_t num_tokens{0}; int32_t num_experts{0}; // Hidden dimension input of MoE block. It might be padded. 
@@ -356,7 +381,7 @@ class Runner { // FIXME: tileTokensDim is hardcoded for now Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim = 8, - GatedActType gatedActType = GatedActType::SwiGlu, bool useShuffledMatrixA = false, + ActivationType activationType = ActivationType::Swiglu, bool useShuffledMatrixA = false, batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK); Runner(batchedGemm::trtllm::gen::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim = 8, bool useShuffledMatrixA = false, diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 89cbf84d4e..3056e051d2 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -22,8 +22,8 @@ from torch.nn import functional as F from flashinfer import ( + ActivationType, RoutingMethodType, - GatedActType, e2m1_and_ufp8sf_scale_to_float, fp4_quantize, mxfp8_dequantize_host, @@ -53,6 +53,10 @@ TUNE_MAX_NUM_TOKENS = 4096 +def is_gated_activation(activation_type: ActivationType) -> bool: + return activation_type in [ActivationType.Swiglu, ActivationType.Geglu] + + def check_cuda(err): """Unified CUDA error checking function used throughout the file.""" if err != runtime.cudaError_t.cudaSuccess: @@ -209,7 +213,7 @@ def _run_moe_computation(self, runtime_args): local_num_experts=self.config["num_experts"], routed_scaling_factor=self.config["routed_scaling"], routing_method_type=self.config["routing_method_type"], - gated_act_type=self.config["gated_act_type"], + activation_type=self.config["activation_type"], do_finalize=True, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) @@ -543,7 +547,7 @@ def call_moe( top_k_groups = kwargs["top_k_groups"] intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] - gated_act_type = kwargs["gated_act_type"] + activation_type = kwargs["activation_type"] 
routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) @@ -556,7 +560,7 @@ def call_moe( "top_k_groups": top_k_groups, "intermediate_size": intermediate_size, "routed_scaling": routed_scaling, - "gated_act_type": gated_act_type, + "activation_type": activation_type, "routing_method_type": routing_method_type, "enable_autotune": enable_autotune, } @@ -1080,14 +1084,16 @@ def prepare_static_weights_for_kernel( # Reorder rows of W1 for fused gated activation gemm1_weights_fp8_interleaved = [] for i in range(num_experts): - gemm1_weights_fp8_interleaved.append( - reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) - ) + if is_gated_activation(args.activation_type): + weights = reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) + else: + weights = args.gemm1_weights[i].clone() + gemm1_weights_fp8_interleaved.append(weights) # Stack weights and scales for all experts gemm1_weights_fp8_interleaved = torch.stack( gemm1_weights_fp8_interleaved - ).reshape(num_experts, 2 * intermediate_size, hidden_size) + ).reshape(num_experts, (2 if is_gated_activation(args.activation_type) else 1) * intermediate_size, hidden_size) # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp8_shuffled = [] @@ -1114,11 +1120,14 @@ def prepare_static_weights_for_kernel( ) # Calculate scaling factors that depend on weights - scale_c_fc1 = ( - args_dequant.c_global_sf - * (1.0 / args.gemm1_scales_global) - * (1.0 / args.hidden_states_scale_global) - ) + if is_gated_activation(args.activation_type): + scale_c_fc1 = ( + args_dequant.c_global_sf + * (1.0 / args.gemm1_scales_global) + * (1.0 / args.hidden_states_scale_global) + ) + else: + scale_c_fc1 = args_dequant.c_global_sf * torch.ones_like(args.gemm1_scales_global) scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( 1.0 / args.hidden_states_scale_global ) @@ -1148,6 +1157,7 @@ def call_moe( routed_scaling = kwargs["routed_scaling"] 
routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) + activation_type = kwargs["activation_type"] # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_fp8, _ = quant_fp8_per_tensor( @@ -1181,6 +1191,7 @@ def call_moe( == RoutingMethodType.Llama4, # Use_routing_scales_on_input routing_method_type, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + activation_type=activation_type, ) return output.to(torch.float) @@ -1383,7 +1394,7 @@ def __init__( gemm2_scales_global, permute_info, use_routing_scales_on_input, - gated_act_type, + activation_type, ): self.num_tokens = num_tokens self.num_experts = num_experts @@ -1403,7 +1414,7 @@ def __init__( self.gemm2_scales_global = gemm2_scales_global self.permute_info = permute_info self.use_routing_scales_on_input = use_routing_scales_on_input - self.gated_act_type = gated_act_type + self.activation_type = activation_type class moe_args_dequant: @@ -1423,7 +1434,7 @@ def __init__( gemm2_weights, permute_info, use_routing_scales_on_input, - gated_act_type, + activation_type, hidden_states_scale=None, ): self.num_tokens = num_tokens @@ -1438,7 +1449,7 @@ def __init__( self.gemm2_weights = gemm2_weights self.permute_info = permute_info self.use_routing_scales_on_input = use_routing_scales_on_input - self.gated_act_type = gated_act_type + self.activation_type = activation_type self.hidden_states_scale = hidden_states_scale @@ -1862,7 +1873,7 @@ def run_moe_dequant(args, quant_mode: QuantMode): # Gemm1 gemm1_output = torch.full( - (total_num_padded_tokens, 2 * args.intermediate_size), + (total_num_padded_tokens, (2 if is_gated_activation(args.activation_type) else 1) * args.intermediate_size), float("nan"), device="cuda", ).to(torch.float) @@ -1897,12 +1908,13 @@ def run_moe_dequant(args, quant_mode: QuantMode): (total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda" ).to(torch.float) - gated_act_type = args.gated_act_type - 
gated_act_type_to_func = { - 0: F.silu, - 1: F.gelu, + activation_type = args.activation_type + activation_type_to_func = { + ActivationType.Swiglu: F.silu, + ActivationType.Geglu: F.gelu, + ActivationType.Relu2: lambda x: F.relu(x) ** 2, } - gated_act_func = gated_act_type_to_func[gated_act_type] + activation_func = activation_type_to_func[activation_type] i = 0 for expert_idx in range(args.num_experts): @@ -1910,9 +1922,13 @@ def run_moe_dequant(args, quant_mode: QuantMode): if my_num_tokens == 0: continue my_a = gemm1_output[i : i + my_num_tokens] - my_x1 = my_a[:, : args.intermediate_size] - my_x2 = my_a[:, args.intermediate_size :] - activation_output[i : i + my_num_tokens] = gated_act_func(my_x2) * my_x1 + if is_gated_activation(args.activation_type): + my_x1 = my_a[:, : args.intermediate_size] + my_x2 = my_a[:, args.intermediate_size :] + activation_output[i : i + my_num_tokens] = activation_func(my_x2) * my_x1 + else: + my_x1 = my_a[:, : args.intermediate_size] + activation_output[i : i + my_num_tokens] = activation_func(my_x1) i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding @@ -2039,7 +2055,7 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.gated_act_type, + args.activation_type, ) return run_moe_dequant(args_dequant, quant_mode), args_dequant @@ -2104,7 +2120,7 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - GatedActType.SwiGlu.value, # gated_act_type + args.activation_type.value, ) return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant @@ -2141,7 +2157,7 @@ def run_moe_reference_per_tensor_scale_fp8(args): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - GatedActType.SwiGlu.value, # gated_act_type + args.activation_type.value, ) return run_moe_dequant(args_dequant, 
QuantMode.FP8_PER_TENSOR), args_dequant @@ -2172,7 +2188,7 @@ def run_moe_reference_bf16(args): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - GatedActType.SwiGlu.value, # gated_act_type + args.activation_type.value, ) return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant @@ -2223,7 +2239,7 @@ def dequantize(weights, scales): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.gated_act_type, + args.activation_type, ) return run_moe_dequant(args_dequant, QuantMode.MXINT4_BF16_BF16), args_dequant @@ -2257,7 +2273,7 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "routed_scaling": kwargs["routed_scaling"], "routing_method_type": kwargs["routing_method_type"], "do_finalize": True, - "gated_act_type": args.gated_act_type, + "activation_type": args.activation_type, "hidden_states_scale": args.hidden_states_scale, "hidden_states_quant": kwargs["hidden_states_quant"], "enable_autotune": kwargs.get("enable_autotune", True), @@ -2285,7 +2301,7 @@ def run_moe_test( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, zero_hidden_states=False, ): @@ -2294,7 +2310,7 @@ def run_moe_test( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, num_tokens, hidden_size, intermediate_size, @@ -2347,7 +2363,7 @@ def run_moe_test( (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 ) gemm1_weights = torch.randn( - (num_experts, 2 * intermediate_size, hidden_size), + (num_experts, (2 if is_gated_activation(activation_type) else 1) * intermediate_size, hidden_size), device="cuda", dtype=torch.bfloat16, ) @@ -2432,7 +2448,7 @@ def run_moe_test( quant_data["gemm2_scales_global"], permute_info, use_routing_scales_on_input, - gated_act_type, + activation_type, ) # Compute reference output @@ -2601,10 +2617,10 @@ def run_moe_test( ], ) @pytest.mark.parametrize( - "gated_act_type", + "activation_type", 
[ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), - pytest.param(GatedActType.GeGlu, id="GeGlu"), + pytest.param(ActivationType.Swiglu, id="Swiglu"), + pytest.param(ActivationType.Geglu, id="Geglu"), ], ) def test_renormalize_routing( @@ -2614,7 +2630,7 @@ def test_renormalize_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, zero_hidden_states, ): @@ -2626,7 +2642,7 @@ def test_renormalize_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, zero_hidden_states=zero_hidden_states, ) @@ -2755,10 +2771,10 @@ def test_renormalize_routing( ], ) @pytest.mark.parametrize( - "gated_act_type", + "activation_type", [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), - pytest.param(GatedActType.GeGlu, id="GeGlu"), + pytest.param(ActivationType.Swiglu, id="Swiglu"), + pytest.param(ActivationType.Geglu, id="Geglu"), ], ) def test_deepseekv3_routing( @@ -2768,7 +2784,7 @@ def test_deepseekv3_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ): """Test DeepSeekV3 routing configurations.""" @@ -2779,7 +2795,7 @@ def test_deepseekv3_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ) @@ -2830,10 +2846,10 @@ def test_deepseekv3_routing( ], ) @pytest.mark.parametrize( - "gated_act_type", + "activation_type", [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), - pytest.param(GatedActType.GeGlu, id="GeGlu"), + pytest.param(ActivationType.Swiglu, id="Swiglu"), + pytest.param(ActivationType.Geglu, id="Geglu"), ], ) def test_topk_routing( @@ -2843,7 +2859,7 @@ def test_topk_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ): """Test TopK routing configuration.""" @@ -2854,7 +2870,7 @@ def test_topk_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + 
activation_type, cache_permute_indices, ) @@ -2904,9 +2920,10 @@ def test_topk_routing( ], ) @pytest.mark.parametrize( - "gated_act_type", + "activation_type", [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(ActivationType.Swiglu, id="Swiglu"), + pytest.param(ActivationType.Relu2, id="Relu2"), ], ) def test_llama4_routing( @@ -2916,7 +2933,7 @@ def test_llama4_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ): """Test Llama4 routing configuration with FP8 per-tensor.""" @@ -2927,6 +2944,6 @@ def test_llama4_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ) diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 19c01d5175..513882934b 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -17,7 +17,7 @@ import pytest import torch from enum import IntEnum -from flashinfer import GatedActType, RoutingMethodType +from flashinfer import ActivationType, RoutingMethodType from flashinfer.utils import get_compute_capability @@ -37,7 +37,7 @@ def skip_checks( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, num_tokens, hidden_size, intermediate_size, @@ -57,20 +57,20 @@ def skip_checks( pytest.skip("Skipping zero hidden states tests for non-FP8 Block Scale MoE.") # Skip incompatible combinations - if gated_act_type == GatedActType.GeGlu and ( + if activation_type == ActivationType.Geglu and ( not is_fp4_moe or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4 or routing_config["routing_method_type"] != RoutingMethodType.TopK or num_tokens > 128 ): pytest.skip( - f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" + f"Incompatible: {moe_impl.name} + {activation_type} + {routing_config['routing_method_type']} + {num_tokens}" ) - elif gated_act_type == GatedActType.SwiGlu and ( + elif activation_type == ActivationType.Swiglu and ( 
hidden_size > 1024 or intermediate_size > 1024 ): pytest.skip( - f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" + f"Skip for testing speed: {activation_type} + {hidden_size} + {intermediate_size}" ) # Skip large intermediate sizes for configurations with many experts From b8eac34895591a6e548166b36a923ccbbf6723c5 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:46:56 +0000 Subject: [PATCH 02/25] Add actType and eltwiseActType to 'no kernel found' message, move is_gated_activation function in tests to tests/moe/utils.py Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_batched_gemm_runner.cu | 2 ++ tests/moe/test_trtllm_gen_fused_moe.py | 6 +----- tests/moe/utils.py | 4 ++++ 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index 455830a4d3..34f6fc8ee8 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -125,6 +125,8 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( << ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB) << ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC) << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8 + << ", mActType: " << (int64_t)mOptions.actType + << ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize; diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 3056e051d2..27cb2ef077 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -46,17 +46,13 @@ get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) -from .utils import skip_checks, 
QuantMode +from .utils import is_gated_activation, skip_checks, QuantMode # Max num tokens to tune for trtllm-gen fused moe TUNE_MAX_NUM_TOKENS = 4096 -def is_gated_activation(activation_type: ActivationType) -> bool: - return activation_type in [ActivationType.Swiglu, ActivationType.Geglu] - - def check_cuda(err): """Unified CUDA error checking function used throughout the file.""" if err != runtime.cudaError_t.cudaSuccess: diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 513882934b..12e2e39efa 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -33,6 +33,10 @@ class QuantMode(IntEnum): MXINT4_BF16_BF16 = 7 +def is_gated_activation(activation_type: ActivationType) -> bool: + return activation_type in [ActivationType.Swiglu, ActivationType.Geglu] + + def skip_checks( moe_impl, routing_config, From f771e0c2b88501ddf7babc9444bdf381bdd61b5a Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:46:56 +0000 Subject: [PATCH 03/25] Update remaining GatedActType uses to ActivationType, remove GatedActType enum from core.py, update docstrings Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- .../bench_trtllm_gen_fused_moe_autotuner.py | 4 +- flashinfer/fused_moe/core.py | 48 ++++++++++++------- include/flashinfer/trtllm/fused_moe/runner.h | 2 +- tests/moe/test_dpsk_fused_moe_fp8.py | 4 +- tests/moe/test_trtllm_gen_routed_fused_moe.py | 6 +-- 5 files changed, 38 insertions(+), 26 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 203faaff82..e7d4831661 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -4,7 +4,7 @@ import numpy as np from flashinfer import ( RoutingMethodType, - GatedActType, + ActivationType, fp4_quantize, mxfp8_quantize, ) @@ -288,7 +288,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( 
RoutingMethodType.Renormalize.value, True, enable_pdl, - GatedActType.SwiGlu.value, # gated_act_type + ActivationType.Swiglu.value, # act_type None, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, ) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 4c9afd48ad..864d8db72e 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -173,15 +173,6 @@ class WeightLayout(IntEnum): BlockMajorK = 2 -# The type of gated activation function -# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h -class GatedActType(IntEnum): - # SwiGlu - SwiGlu = 0 - # GeGlu - GeGlu = 1 - - @functools.cache def is_trtllm_moe_supported( dtype_weights: DtypeTrtllmGen, @@ -2191,7 +2182,7 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, - activation_type: ActivationType = ActivationType.Identity, + activation_type: int = ActivationType.Identity.value, ) -> torch.Tensor: """FP8 per tensor scale MoE operation. @@ -2216,6 +2207,15 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) + activation_type (int): Type of activation function (default: 7 - Identity) + - 0: Gelu + - 1: Relu + - 2: Silu + - 3: Swiglu + - 4: Geglu + - 5: SwigluBias + - 6: Relu2 + - 7: Identity Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] @@ -2407,9 +2407,15 @@ def trtllm_fp4_block_scale_moe( - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. 
- gated_act_type (int): Type of gated activation function (default: 0) - - 0: SwiGlu - - 1: GeGlu + activation_type (int): Type of activation function (default: 0) + - 0: Gelu + - 1: Relu + - 2: Silu + - 3: Swiglu + - 4: Geglu + - 5: SwigluBias + - 6: Relu2 + - 7: Identity tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. @@ -2482,7 +2488,7 @@ def trtllm_fp4_block_scale_routed_moe( routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, - gated_act_type: int = 0, + activation_type: int = ActivationType.Swiglu.value, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -2538,9 +2544,15 @@ def trtllm_fp4_block_scale_routed_moe( - 3: Llama4 (Top1 -> Sigmoid) - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) - gated_act_type (int): Type of gated activation function (default: 0) - - 0: SwiGlu - - 1: GeGlu + activation_type (int): Type of activation function (default: 3 - Swiglu) + - 0: Gelu + - 1: Relu + - 2: Silu + - 3: Swiglu + - 4: Geglu + - 5: SwigluBias + - 6: Relu2 + - 7: Identity tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. 
@@ -2579,7 +2591,7 @@ def trtllm_fp4_block_scale_routed_moe( routing_method_type, do_finalize, enable_pdl, - gated_act_type, + activation_type, output, tune_max_num_tokens, ) diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 9673df1d2d..4cd750b790 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -169,7 +169,7 @@ inline std::string serializeActivationType(ActivationType activationType) { case ActivationType::Identity: return "Identity"; default: - return "InvalidType"; // TODO throw error + return "InvalidActivationType"; // TODO throw error }; } diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py index 711e05f234..d1661c1759 100644 --- a/tests/moe/test_dpsk_fused_moe_fp8.py +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -8,7 +8,7 @@ trtllm_fp8_block_scale_moe, ) from .utils import skip_checks, QuantMode -from flashinfer import GatedActType +from flashinfer import ActivationType def dequant_fp8_block_scaled( @@ -616,7 +616,7 @@ def __init__(self): moe_impl=moe_impl, routing_config=routing_config, weight_processing=weight_processing, - gated_act_type=GatedActType.SwiGlu, + gated_act_type=ActivationType.Swiglu, num_tokens=seq_len, hidden_size=7168, # DeepSeek-V3 hidden size intermediate_size=intermediate_size, diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index fb3feba4b7..dfce612a64 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -20,7 +20,7 @@ from flashinfer import ( RoutingMethodType, - GatedActType, + ActivationType, fp4_quantize, mxfp8_quantize, ) @@ -183,7 +183,7 @@ def test_trtllm_gen_routed_fused_moe( routing_method_type.value, True, # do_finalize enable_pdl, - GatedActType.SwiGlu.value, # gated_act_type + ActivationType.Swiglu.value, # act_type None, )[0].to(torch.float) @@ 
-236,7 +236,7 @@ def test_trtllm_gen_routed_fused_moe( routing_method_type.value, True, # do_finalize enable_pdl, - GatedActType.SwiGlu.value, # gated_act_type + ActivationType.Swiglu.value, # act_type None, )[0].to(torch.float) From 440c0625c29a4ec84a333c8e1f6761c2d86c1d6b Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:46:56 +0000 Subject: [PATCH 04/25] Use ActivationType in benchmarks, add missing activation_type argument Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- .../bench_trtllm_gen_fused_moe_autotuner.py | 15 +++++++++++ benchmarks/routines/moe.py | 26 +++++++++---------- flashinfer/fused_moe/core.py | 2 ++ 3 files changed, 29 insertions(+), 14 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index e7d4831661..6a2a9d6b53 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -39,6 +39,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( top_k: int, warmups: int, iterations: int, + activation_type: ActivationType, ): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) @@ -97,6 +98,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( ) if is_block_scale: + if activation_type != ActivationType.Swiglu: + raise ValueError( + "Only Swiglu activation is supported for FP8 block scale MoE." 
+ ) fn = lambda: trtllm_fp8_block_scale_moe( routing_logits, routing_bias, @@ -144,6 +149,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( RoutingMethodType.TopK.value, enable_pdl, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + activation_type.value, ) def bench(do_autotune): @@ -348,6 +354,14 @@ def bench(do_autotune): parser.add_argument( "--iterations", type=int, default=100, help="Number of benchmark iterations" ) + parser.add_argument( + "--activation-type", + type=ActivationType, + choices=list(ActivationType), + required=False, + default=ActivationType.Swiglu, + help=f"Type of gated activation function: {list(ActivationType)}", + ) args = parser.parse_args() if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]: bench_trtllm_gen_fused_moe_autotuner_fp8( @@ -360,6 +374,7 @@ def bench(do_autotune): args.top_k, args.warmups, args.iterations, + args.activation_type, ) else: bench_trtllm_gen_fused_moe_autotuner_fp4( diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index ca9214511a..a6c24c3b9a 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -5,6 +5,7 @@ import torch import flashinfer +from flashinfer import ActivationType from flashinfer.autotuner import autotune from flashinfer.fused_moe import ( WeightLayout, @@ -175,12 +176,12 @@ def parse_moe_args(line, parser): help="Data type of the weights (before quantization).", ) parser.add_argument( - "--gated_act", - type=str, + "--activation-type", + type=ActivationType, + choices=list(ActivationType), required=False, - default="swiglu", - choices=["swiglu", "geglu"], - help="Type of gated activation function: swiglu | geglu.", + default=ActivationType.Swiglu, + help=f"Type of gated activation function: {list(ActivationType)}", ) parser.add_argument( "--autotune", @@ -247,13 +248,7 @@ def parse_moe_args(line, parser): } args.routing_method_type = routing_method_name_to_type[args.routing_method] - # Normalize gated act type (map string to internal 
int expected by kernels) - gated_act_name_to_type = { - "swiglu": 0, - "geglu": 1, - } - args.gated_act_type = gated_act_name_to_type[args.gated_act] - + args.activation_type = args.activation_type if args.verbose >= 1: print(f"[INFO] {args = }") return args @@ -630,7 +625,7 @@ def testTrtllmFp4BlockScaleMoe(args): use_shuffled_weight = args.use_shuffled_weight weight_layout = args.weight_layout is_cuda_graph_compatible = not args.no_cuda_graph - gated_act_type = args.gated_act_type + activation_type = args.activation_type res = [] backends = ["trtllm"] @@ -795,7 +790,7 @@ def run_fp4_moe( local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, routing_method_type=routing_method_type, - gated_act_type=gated_act_type, + activation_type=activation_type.value, do_finalize=True, ) @@ -1671,6 +1666,7 @@ def run_fp8_per_tensor_moe( output1_scales_gate_scalar, gemm2_weights_fp8, output2_scales_scalar, + activation_type, ): # Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t] # So we convert None to 0 to indicate "no groups" mode @@ -1693,6 +1689,7 @@ def run_fp8_per_tensor_moe( routed_scaling_factor=routed_scaling_factor, use_routing_scales_on_input=use_routing_scales_on_input, routing_method_type=routing_method_type, + activation_type=activation_type.value, ) # Benchmark timing @@ -1713,6 +1710,7 @@ def run_fp8_per_tensor_moe( output1_scales_gate_scalar, gemm2_weights_fp8, output2_scales_scalar, + args.activation_type ), ) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 864d8db72e..602ed7e740 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1168,6 +1168,7 @@ def forward( kwargs["routing_method_type"], kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, + self.activation_type, ) elif ( self.dtype_act == DtypeTrtllmGen.Bfloat16 @@ -1526,6 +1527,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( use_routing_scales_on_input: bool, 
routing_method_type: int = 0, enable_pdl: Optional[bool] = None, + activation_type: int = ActivationType.Identity.value, ): seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] From 57257396c3983fcf1cf568c0d8dcb350ca1ec45e Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:46:56 +0000 Subject: [PATCH 05/25] Minor fixes Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- benchmarks/routines/moe.py | 3 +-- tests/moe/test_dpsk_fused_moe_fp8.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index a6c24c3b9a..f00260d109 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -248,7 +248,6 @@ def parse_moe_args(line, parser): } args.routing_method_type = routing_method_name_to_type[args.routing_method] - args.activation_type = args.activation_type if args.verbose >= 1: print(f"[INFO] {args = }") return args @@ -1710,7 +1709,7 @@ def run_fp8_per_tensor_moe( output1_scales_gate_scalar, gemm2_weights_fp8, output2_scales_scalar, - args.activation_type + args.activation_type, ), ) diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py index d1661c1759..cd44f2faf2 100644 --- a/tests/moe/test_dpsk_fused_moe_fp8.py +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -616,7 +616,7 @@ def __init__(self): moe_impl=moe_impl, routing_config=routing_config, weight_processing=weight_processing, - gated_act_type=ActivationType.Swiglu, + activation_type=ActivationType.Swiglu, num_tokens=seq_len, hidden_size=7168, # DeepSeek-V3 hidden size intermediate_size=intermediate_size, From c2c8531d08dc79f898633885a4ea298c86ed5862 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:55:00 +0000 Subject: [PATCH 06/25] Fix activation_type default value to Swiglu on trtllm_fp4_block_scale_moe Signed-off-by: amitz-nv 
<203509407+amitz-nv@users.noreply.github.com> --- flashinfer/fused_moe/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 602ed7e740..7ab73bf7eb 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1719,7 +1719,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type: int, do_finalize: bool, enable_pdl: Optional[bool] = None, - activation_type: int = 0, + activation_type: int = ActivationType.Swiglu.value, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -2354,7 +2354,7 @@ def trtllm_fp4_block_scale_moe( routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, - activation_type: int = 0, + activation_type: int = ActivationType.Swiglu.value, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -2409,7 +2409,7 @@ def trtllm_fp4_block_scale_moe( - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. 
- activation_type (int): Type of activation function (default: 0) + activation_type (int): Type of activation function (default: 3 - Swiglu) - 0: Gelu - 1: Relu - 2: Silu From bb4e8214b4d3e65713d18a85f1d214a9c8c40e71 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:55:54 +0000 Subject: [PATCH 07/25] Minor improvement Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- tests/moe/test_trtllm_gen_fused_moe.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 27cb2ef077..509ef7ec5a 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -1123,7 +1123,9 @@ def prepare_static_weights_for_kernel( * (1.0 / args.hidden_states_scale_global) ) else: - scale_c_fc1 = args_dequant.c_global_sf * torch.ones_like(args.gemm1_scales_global) + scale_c_fc1 = torch.full_like( + args.gemm1_scales_global, args_dequant.c_global_sf + ) scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( 1.0 / args.hidden_states_scale_global ) From c6ac4afec4fcd747f1e269b449c7074c32bf428d Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:55:58 +0000 Subject: [PATCH 08/25] Support non-gated activation in NVFP4 block scale MoE Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 5 +++- flashinfer/fused_moe/core.py | 6 ++++- tests/moe/test_trtllm_gen_fused_moe.py | 29 +++++++++++++++++------- 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 618c17c0da..833eb082cf 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -1666,8 +1666,11 @@ Array trtllm_fp4_block_scale_moe( if 
(hidden_states_scale.has_value()) { hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel(); } + int64_t intermediate_size_factor = + isGatedActivation(static_cast(act_type)) ? 2 : 1; int weight_scale_vec_size = - (local_num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel(); + (local_num_experts * intermediate_size * intermediate_size_factor * hidden_size) / + gemm1_weights_scale.numel(); TVM_FFI_ICHECK(weight_scale_vec_size == 16 || weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size."; diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 7ab73bf7eb..124292d75b 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -212,12 +212,16 @@ def _maybe_get_cached_w3_w1_permute_indices( dst_w3_w1_weight: torch.Tensor, epilogue_tile_m: int, num_elts_per_sf: Union[None, int] = None, + is_gated_act_gemm: bool = True, ) -> torch.Tensor: # Create a unique cache key (weight_type, weight_shape) cache_key = ("w3_w1", dst_w3_w1_weight.shape) if cache_key not in _cache_permute_indices: # Get permute indices and chain them together - permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight) + if is_gated_act_gemm: + permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight) + else: + permute0 = torch.arange(dst_w3_w1_weight.shape[0], dtype=torch.long) if num_elts_per_sf is None: permute1 = get_shuffle_matrix_a_row_indices( dst_w3_w1_weight, epilogue_tile_m=epilogue_tile_m diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 509ef7ec5a..ce60313368 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -408,13 +408,16 @@ def prepare_static_weights_for_kernel( ) # Convert quantized weights to proper formats + intermediate_size_factor = 2 if is_gated_activation(args.activation_type) else 1 gemm1_weights_fp4 = 
args.gemm1_weights.view(torch.float8_e4m3fn).reshape( - num_experts, 2 * intermediate_size, hidden_size // 2 + num_experts, intermediate_size_factor * intermediate_size, hidden_size // 2 ) # packed fp4 gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( torch.float8_e4m3fn ).reshape( - num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size + num_experts, + intermediate_size_factor * intermediate_size, + hidden_size // self.sf_vec_size, ) # fp8 scaling factors gemm2_weights_fp4 = args.gemm2_weights.view(torch.float8_e4m3fn).reshape( @@ -440,6 +443,7 @@ def prepare_static_weights_for_kernel( self._cache_permute_indices, gemm1_weights_fp4[i].view(torch.uint8), epilogue_tile_m, + is_gated_act_gemm=is_gated_activation(args.activation_type), ) gemm1_weights_fp4_shuffled.append( gemm1_weights_fp4[i] @@ -452,6 +456,7 @@ def prepare_static_weights_for_kernel( gemm1_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m, num_elts_per_sf=16, + is_gated_act_gemm=is_gated_activation(args.activation_type), ) gemm1_scales_fp4_shuffled.append( block_scale_interleave( @@ -496,7 +501,9 @@ def prepare_static_weights_for_kernel( torch.stack(gemm1_scales_fp4_shuffled) .view(torch.float8_e4m3fn) .reshape( - num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size + num_experts, + intermediate_size_factor * intermediate_size, + hidden_size // self.sf_vec_size, ) ) @@ -508,11 +515,16 @@ def prepare_static_weights_for_kernel( ) # Calculate scaling factors that depend on weights - scale_c_fc1 = ( - args_dequant.c_global_sf - * (1.0 / args.gemm1_scales_global) - * (1.0 / args.hidden_states_scale_global) - ) + if is_gated_activation(args.activation_type): + scale_c_fc1 = ( + args_dequant.c_global_sf + * (1.0 / args.gemm1_scales_global) + * (1.0 / args.hidden_states_scale_global) + ) + else: + scale_c_fc1 = torch.full_like( + args.gemm1_scales_global, args_dequant.c_global_sf + ) scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( 1.0 / 
args.hidden_states_scale_global ) @@ -2848,6 +2860,7 @@ def test_deepseekv3_routing( [ pytest.param(ActivationType.Swiglu, id="Swiglu"), pytest.param(ActivationType.Geglu, id="Geglu"), + pytest.param(ActivationType.Relu2, id="Relu2"), ], ) def test_topk_routing( From 3bf918e09aeb7cd6fab7d59e50c73b5bed380d22 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:58:28 +0000 Subject: [PATCH 09/25] Rename useShuffledMatrixA to useShuffledMatrix (remove the 'A' suffix) Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_batched_gemm_runner.cu | 3 +- csrc/trtllm_fused_moe_kernel_launcher.cu | 6 ++-- csrc/trtllm_fused_moe_runner.cu | 28 +++++++++---------- csrc/trtllm_low_latency_gemm_runner.cu | 2 +- .../trtllm/batched_gemm/KernelRunner.h | 2 +- include/flashinfer/trtllm/fused_moe/runner.h | 8 +++--- 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index 34f6fc8ee8..9982974953 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -101,8 +101,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( options.mTransposeMmaOutput == mOptions.transposeMmaOutput && (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct && options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch && - tileSize == mOptions.tileSize && - options.mUseShuffledMatrix == mOptions.useShuffledMatrixA && + tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix && options.mLayoutA == mOptions.weightLayout) { if (options.mFusedAct) { if (options.mActType != static_cast(mOptions.actType)) { diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 833eb082cf..0d61da7305 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu 
@@ -1073,7 +1073,7 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { btg::Dtype::Bfloat16, btg::Dtype::MxInt4, false, // useDeepSeekFp8 tile_N, ActivationType::Swiglu, - /*useShuffledMatrixA*/ true, batchedGemm::gemm::MatrixLayout::BlockMajorK); + /*useShuffledMatrix*/ true, batchedGemm::gemm::MatrixLayout::BlockMajorK); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, num_tokens); @@ -1384,7 +1384,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { dtype_act, dtype_weights, false, // useDeepSeekFp8 tile_N, static_cast(act_type), - /*useShuffledMatrixA*/ true); // FP4 uses shuffled weights + /*useShuffledMatrix*/ true); // FP4 uses shuffled weights auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, num_tokens); @@ -1505,7 +1505,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe( auto const hidden_size = hidden_states.size(1); // Use default values that match the original function behavior - bool use_shuffled_weight = true; // Original uses /*useShuffledMatrixA*/ true + bool use_shuffled_weight = true; // Original uses /*useShuffledMatrix*/ true int64_t weight_layout = 0; // Default to MajorK // Calculate supported tile sizes diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index eef0ba2473..608376329a 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -221,7 +221,7 @@ static inline EltwiseActType activationTypeToEltwiseActType(ActivationType actTy tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( btg::Dtype dtypeAct, btg::Dtype dtypeWeights, int32_t tileTokensDim, bool useDeepSeekFp8, - ActivationType activationType, bool useShuffledMatrixA, + ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) { int64_t actTypeInt = static_cast(activationType); FLASHINFER_CHECK(0 <= actTypeInt && actTypeInt < 
static_cast(ActivationType::InvalidType), @@ -242,7 +242,7 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 64 : 128, - .useShuffledMatrixA = useShuffledMatrixA, + .useShuffledMatrix = useShuffledMatrix, .weightLayout = weightLayout}; return options; } else { @@ -260,21 +260,21 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = 128, - .useShuffledMatrixA = useShuffledMatrixA, + .useShuffledMatrix = useShuffledMatrix, .weightLayout = weightLayout}; return options; } } Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim, - ActivationType activationType, bool useShuffledMatrixA, + ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) : mDtypeAct(dtypeAct), mDtypeWeights(dtypeWeights), mTileTokensDim(tileTokensDim), mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, activationType, - useShuffledMatrixA, weightLayout))), + useShuffledMatrix, weightLayout))), mActType(activationType) {} void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* weightsScale, @@ -336,7 +336,7 @@ std::vector Runner::getPassingConfigIndices() const { namespace Gemm2 { tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, int32_t tileTokensDim, - bool useDeepSeekFp8, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) { + bool useDeepSeekFp8, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) { tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { // Swap A and B dtypes because transposeMmaOutput is hardcoded to true .dtypeA = dtypeWeights, @@ -350,13 +350,13 @@ 
tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 64 : 128, - .useShuffledMatrixA = useShuffledMatrixA, + .useShuffledMatrix = useShuffledMatrix, .weightLayout = weightLayout}; return options; } Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, - bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrixA, + bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) : mDtypeAct(dtypeAct), mDtypeWeights(dtypeWeights), @@ -364,7 +364,7 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut mTileTokensDim(tileTokensDim), mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( getOptions(dtypeAct, dtypeWeights, dtypeOut, tileTokensDim, useDeepSeekFp8, - useShuffledMatrixA, weightLayout))) {} + useShuffledMatrix, weightLayout))) {} void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void* weights, void* weightsScale, float* outputScalesScalar, float* ptrBias, void* output, @@ -424,12 +424,12 @@ std::vector Runner::getPassingConfigIndices() const { namespace MoE { Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, - int32_t tileTokensDim, ActivationType activationType, bool useShuffledMatrixA, + int32_t tileTokensDim, ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) : mPermuteGemm1(PermuteGemm1::Runner(dtypeAct, dtypeWeights, useDeepSeekFp8, tileTokensDim, - activationType, useShuffledMatrixA, weightLayout)), + activationType, useShuffledMatrix, weightLayout)), mGemm2(Gemm2::Runner(dtypeAct, dtypeWeights, btg::Dtype::Bfloat16, useDeepSeekFp8, - tileTokensDim, useShuffledMatrixA, weightLayout)) { + tileTokensDim, useShuffledMatrix, weightLayout)) { auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices(); auto const& 
gemm2PassingIndices = mGemm2.getPassingConfigIndices(); @@ -446,9 +446,9 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8 } Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim, - bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) + bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, ActivationType::Swiglu, - useShuffledMatrixA, weightLayout) {} + useShuffledMatrix, weightLayout) {} void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace, moe::dev::convertsf::Data& convertSfData, diff --git a/csrc/trtllm_low_latency_gemm_runner.cu b/csrc/trtllm_low_latency_gemm_runner.cu index f3ce0d43c3..99639d9687 100644 --- a/csrc/trtllm_low_latency_gemm_runner.cu +++ b/csrc/trtllm_low_latency_gemm_runner.cu @@ -166,7 +166,7 @@ class TrtllmLowLatencyGemmRunner { configOptions.mDtypeC == mOptions.outputType && configOptions.mTransposeMmaOutput == true && configOptions.mLayoutA == gemm::gemm::MatrixLayout::BlockMajorK && - configOptions.mUseShuffledMatrixA) { + configOptions.mUseShuffledMatrix) { mPassingConfigIndices.push_back(i); } } diff --git a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h index f73c14a5be..54cd824c0e 100644 --- a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h +++ b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h @@ -75,7 +75,7 @@ struct TrtllmGenBatchedGemmRunnerOptions { bool transposeMmaOutput{false}; int32_t tileSize{8}; int32_t epilogueTileM{128}; - bool useShuffledMatrixA{false}; + bool useShuffledMatrix{false}; batchedGemm::gemm::MatrixLayout weightLayout{batchedGemm::gemm::MatrixLayout::MajorK}; }; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 4cd750b790..d8eb59b7be 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ 
b/include/flashinfer/trtllm/fused_moe/runner.h @@ -184,7 +184,7 @@ class Runner { public: explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8, - int tileTokensDim, MoE::ActivationType activationType, bool useShuffledMatrixA, + int tileTokensDim, MoE::ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weight_layout); size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -225,7 +225,7 @@ class Runner { explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, batchedGemm::trtllm::gen::Dtype outputDtype, bool useDeepSeekFp8, - int tileTokensDim, bool useShuffledMatrixA, + int tileTokensDim, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weight_layout); size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -381,10 +381,10 @@ class Runner { // FIXME: tileTokensDim is hardcoded for now Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim = 8, - ActivationType activationType = ActivationType::Swiglu, bool useShuffledMatrixA = false, + ActivationType activationType = ActivationType::Swiglu, bool useShuffledMatrix = false, batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK); Runner(batchedGemm::trtllm::gen::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim = 8, - bool useShuffledMatrixA = false, + bool useShuffledMatrix = false, batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK); void run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, From 1193b029848365798695e2777e438d785ca80d26 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:58:32 +0000 Subject: [PATCH 10/25] Add FP4_NVFP4_NVFP4 
parameterization to test_llama4_routing, update tests skip_checks to skip on non-gated activation with quantizations that don't support it Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- tests/moe/test_trtllm_gen_fused_moe.py | 49 +++++++++++++++++++++++--- tests/moe/utils.py | 16 ++++++++- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index ce60313368..9b9de593f7 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -227,6 +227,12 @@ class Moe(ABC): def __init__(self): self.name = self.__class__.__name__ + @property + @abstractmethod + def quant_mode(self) -> QuantMode: + """Get the quantization mode of this MoE implementation.""" + pass + @abstractmethod def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize static weights and compute global scale factors (done offline).""" @@ -305,13 +311,17 @@ class FP4Moe(Moe): def __init__(self, quant_mode: QuantMode): super().__init__() - self.quant_mode = quant_mode + self._quant_mode = quant_mode self.is_mxfp4 = ( quant_mode == QuantMode.FP4_MXFP4_MXFP8 or quant_mode == QuantMode.FP4_MXFP4_Bf16 ) self.sf_vec_size = 32 if self.is_mxfp4 else 16 + @property + def quant_mode(self) -> QuantMode: + return self._quant_mode + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP4 format and compute global scale factors.""" num_experts = gemm1_weights.shape[0] @@ -622,6 +632,10 @@ def mxint4_quantize( class MxInt4BlockScaleMoe(Moe): """MxInt4 MoE implementation with block scaling (DeepSeek style).""" + @property + def quant_mode(self) -> QuantMode: + return QuantMode.MXINT4_BF16_BF16 + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to MxInt4 with block scaling.""" num_experts = gemm1_weights.shape[0] @@ -816,6 +830,10 @@ def 
get_tolerances(self): class FP8BlockScaleMoe(Moe): """FP8 MoE implementation with block scaling (DeepSeek style).""" + @property + def quant_mode(self) -> QuantMode: + return QuantMode.FP8_BLOCK_SCALE + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 with block scaling.""" num_experts = gemm1_weights.shape[0] @@ -1037,6 +1055,10 @@ def get_tolerances(self): class FP8PerTensorMoe(Moe): """FP8 MoE implementation with per-tensor scaling (Llama4 style).""" + @property + def quant_mode(self) -> QuantMode: + return QuantMode.FP8_PER_TENSOR + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 per-tensor and compute global scale factors.""" # Compute global scale factor for hidden states (offline calibration) @@ -1101,7 +1123,11 @@ def prepare_static_weights_for_kernel( # Stack weights and scales for all experts gemm1_weights_fp8_interleaved = torch.stack( gemm1_weights_fp8_interleaved - ).reshape(num_experts, (2 if is_gated_activation(args.activation_type) else 1) * intermediate_size, hidden_size) + ).reshape( + num_experts, + (2 if is_gated_activation(args.activation_type) else 1) * intermediate_size, + hidden_size, + ) # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp8_shuffled = [] @@ -1223,6 +1249,10 @@ def get_tolerances(self): class BF16Moe(Moe): """BF16 MoE implementation.""" + @property + def quant_mode(self) -> QuantMode: + return QuantMode.BF16 + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """No scaling for weights.""" return { @@ -1883,7 +1913,11 @@ def run_moe_dequant(args, quant_mode: QuantMode): # Gemm1 gemm1_output = torch.full( - (total_num_padded_tokens, (2 if is_gated_activation(args.activation_type) else 1) * args.intermediate_size), + ( + total_num_padded_tokens, + (2 if is_gated_activation(args.activation_type) else 1) + * args.intermediate_size, + ), float("nan"), 
device="cuda", ).to(torch.float) @@ -2373,7 +2407,11 @@ def run_moe_test( (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 ) gemm1_weights = torch.randn( - (num_experts, (2 if is_gated_activation(activation_type) else 1) * intermediate_size, hidden_size), + ( + num_experts, + (2 if is_gated_activation(activation_type) else 1) * intermediate_size, + hidden_size, + ), device="cuda", dtype=torch.bfloat16, ) @@ -2894,6 +2932,7 @@ def test_topk_routing( "moe_impl", [ pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), + pytest.param(FP4Moe(QuantMode.FP4_NVFP4_NVFP4), id="FP4"), ], ) @pytest.mark.parametrize( @@ -2909,7 +2948,7 @@ def test_topk_routing( "routed_scaling": 2.5, "has_routing_bias": True, "routing_method_type": RoutingMethodType.Llama4, - "compatible_moe_impls": [FP8PerTensorMoe], + "compatible_moe_impls": [FP8PerTensorMoe, FP4Moe], "compatible_intermediate_size": [1024, 2048], "enable_autotune": True, }, diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 12e2e39efa..ae3fdaab6e 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -33,6 +33,12 @@ class QuantMode(IntEnum): MXINT4_BF16_BF16 = 7 +NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES = [ + QuantMode.FP4_NVFP4_NVFP4, + QuantMode.FP8_PER_TENSOR, +] + + def is_gated_activation(activation_type: ActivationType) -> bool: return activation_type in [ActivationType.Swiglu, ActivationType.Geglu] @@ -77,6 +83,14 @@ def skip_checks( f"Skip for testing speed: {activation_type} + {hidden_size} + {intermediate_size}" ) + if ( + not is_gated_activation(activation_type) + and moe_impl.quant_mode not in NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES + ): + pytest.skip( + f"Incompatible: {moe_impl.name} + {activation_type=} + quant_mode={moe_impl.quant_mode}: non-gated activations only supported with these quant modes: {NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES}" + ) + # Skip large intermediate sizes for configurations with many experts if routing_config["num_experts"] >= 512 and intermediate_size > 
512: pytest.skip( @@ -96,7 +110,7 @@ def skip_checks( f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)" ) - if type(moe_impl).__name__ == "MxInt4BlockScaleMoe" and ( + if moe_impl.quant_mode == QuantMode.MXINT4_BF16_BF16 and ( intermediate_size % 256 != 0 or hidden_size % 256 != 0 ): pytest.skip( From b0e6d599b4f5aa03b530ddc695569b9fa58580e2 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:58:33 +0000 Subject: [PATCH 11/25] Increase supported topK and num experts in deepseek routing for nemotron Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_routing_deepseek.cu | 67 +++++++++++++++-------- csrc/trtllm_fused_moe_runner.cu | 2 +- tests/moe/test_trtllm_gen_fused_moe.py | 23 +++++++- tests/moe/utils.py | 2 +- 4 files changed, 66 insertions(+), 28 deletions(-) diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 21faec8ec7..ccb5209e40 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -15,6 +15,7 @@ */ #include +#include #include "flashinfer/exception.h" #include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" @@ -25,10 +26,14 @@ namespace routingDeepSeek { //////////////////////////////////////////////////////////////////////////////////////////////////// +static constexpr int NumNemotronExperts = 512; static constexpr int NumKimiK2Experts = 384; static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = + std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); static constexpr int NumTopGroupScores = 2; -static constexpr int MaxNumTopExperts = 8; +static constexpr int DefaultMaxNumTopExperts = 8; +static constexpr int MaxSupportedTopExperts = 22; static constexpr int MaxNumTopGroups = 4; static 
constexpr int MaxNumGroups = 8; @@ -117,8 +122,8 @@ __global__ void routingMainKernel(KernelParams params) { int32_t topGroupIdx[MaxNumTopGroups]; float expertScoreGroup[MaxNumTopGroups]; int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[MaxNumTopExperts]; // bound of params.mTopK - int32_t topExperts[MaxNumTopExperts]; + float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK + int32_t topExperts[KernelParams::MaxNumTopExperts]; if constexpr (KernelParams::UseGroups) { topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, @@ -154,7 +159,8 @@ __global__ void routingMainKernel(KernelParams params) { // params.mNumExpertsPerGroup // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, // so the access is safe here - expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected + expertScoreGroup[ii] = (ii < params.mNumLimitedGroups) && + (groupIdx < params.mNumExpertGroups) && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat; } @@ -166,7 +172,7 @@ __global__ void routingMainKernel(KernelParams params) { // without groups, each thread just takes `MaxNumTopGroups` experts int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; + int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; if (warpIdx < NumExpertWarps) { @@ -183,13 +189,20 @@ __global__ void routingMainKernel(KernelParams params) { /* minValue */ invalidScoreFloat, params.mTopK); if (laneIdx < params.mTopK) { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + smemInterTopScores[warpIdx * 
KernelParams::MaxNumTopExperts + laneIdx] = + topScores[laneIdx]; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + topExperts[laneIdx]; + } else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + invalidScoreFloat; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + MaxSupportedExpertCount - 1; } } __syncthreads(); if (warpIdx == 0) { - int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1; + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; float intermidiateScore[NumInterTopKPerThread]; int32_t intermidiateExpert[NumInterTopKPerThread]; for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { @@ -270,7 +283,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) cudaGridDependencySynchronize(); } routingPermutation(params, nullptr, warpIdx, clusterBlockRank); } #else @@ -493,6 +506,8 @@ int constexpr getMaxNumExperts(int32_t numExperts) { return NumDeepseekExperts; } else if (numExperts <= NumKimiK2Experts) { return NumKimiK2Experts; + } else if (numExperts <= NumNemotronExperts) { + return NumNemotronExperts; } else { TLLM_LOG_ERROR("Unsupported numExperts"); return 0; @@ -504,13 +519,23 @@ int constexpr getMaxNumExperts(int32_t numExperts) { extraFlag) \ if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, topk::MaxNumExpertsUnit); \ + stream, extraFlag, topk::MaxNumExpertsUnit, \ + DefaultMaxNumTopExperts); \ } else if (data.mNumExperts <= NumDeepseekExperts) { \ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumDeepseekExperts); \ + stream, extraFlag, NumDeepseekExperts, DefaultMaxNumTopExperts); \ } else if (data.mNumExperts <= NumKimiK2Experts) { 
\ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumKimiK2Experts); \ + stream, extraFlag, NumKimiK2Experts, DefaultMaxNumTopExperts); \ + } else if (data.mNumExperts <= NumNemotronExperts) { \ + if (data.mTopK <= DefaultMaxNumTopExperts) { \ + LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, NumNemotronExperts, \ + DefaultMaxNumTopExperts); \ + } else { \ + LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, NumNemotronExperts, MaxSupportedTopExperts); \ + } \ } else { \ TLLM_LOG_ERROR("Unsupported numExperts"); \ } @@ -532,20 +557,20 @@ void runImpl(Data& data, void* stream) { FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, - "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, + FLASHINFER_CHECK(data.mTopK <= MaxSupportedTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts, data.mTopK); FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK); FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK, data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts, - "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts, + FLASHINFER_CHECK(data.mNumExperts >= MaxSupportedTopExperts, + "Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, data.mNumExperts); - FLASHINFER_CHECK(data.mNumExperts <= NumKimiK2Experts, + FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExpertCount, "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, - NumKimiK2Experts); + 
MaxSupportedExpertCount); FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, data.mNumExpertGroups); @@ -560,10 +585,6 @@ void runImpl(Data& data, void* stream) { data.mNumExperts / data.mNumExpertGroups <= WarpSize, "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); - } else { - FLASHINFER_CHECK(data.mTopK <= topk::MaxNumTopK, - "Routing kernel expects top K %d to be <= #warps %d", data.mTopK, - topk::MaxNumTopK); } FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); @@ -598,7 +619,7 @@ void runImpl(Data& data, void* stream) { int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; if (data.mPtrTopKIds == nullptr) { int const numThreadsMain = - data.mNumExperts < NumDeepseekExperts ? 
NumDeepseekExperts : NumKimiK2Experts; + max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); LAUNCH_ROUTING_DEEPSEEK(data, /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, /*smemSize=*/0, // No dynamic smem diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 608376329a..3db24f43c1 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -60,7 +60,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) { if (routingMethodType == RoutingMethodType::DeepSeekV3) { - FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8"); + FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); moe::dev::routing::routingDeepSeek::Data routingData; routingData.mDtypeExpW = diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 9b9de593f7..4d2d56380e 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -2379,7 +2379,7 @@ def run_moe_test( # Validation checks assert top_k <= num_experts - assert top_k <= 10 + assert top_k <= 22 if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0): assert top_k_groups <= 4 assert num_experts > n_groups @@ -2699,10 +2699,11 @@ def test_renormalize_routing( # Test: DeepSeekV3 routing @pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) -@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize("intermediate_size", [2688, 2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ + pytest.param(FP8PerTensorMoe(), id="FP8_PerTensor"), pytest.param(FP8BlockScaleMoe(), 
id="FP8_Block"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), @@ -2714,6 +2715,22 @@ def test_renormalize_routing( @pytest.mark.parametrize( "routing_config", [ + pytest.param( + { + "num_experts": 512, + "top_k": 22, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP8PerTensorMoe, FP4Moe, BF16Moe], + "compatible_intermediate_size": [1024, 2688], + "enable_autotune": True, + }, + id="nemotron_3", + ), pytest.param( { "num_experts": 384, @@ -2823,6 +2840,7 @@ def test_renormalize_routing( [ pytest.param(ActivationType.Swiglu, id="Swiglu"), pytest.param(ActivationType.Geglu, id="Geglu"), + pytest.param(ActivationType.Relu2, id="Relu2"), ], ) def test_deepseekv3_routing( @@ -2898,7 +2916,6 @@ def test_deepseekv3_routing( [ pytest.param(ActivationType.Swiglu, id="Swiglu"), pytest.param(ActivationType.Geglu, id="Geglu"), - pytest.param(ActivationType.Relu2, id="Relu2"), ], ) def test_topk_routing( diff --git a/tests/moe/utils.py b/tests/moe/utils.py index ae3fdaab6e..42f5c06418 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -92,7 +92,7 @@ def skip_checks( ) # Skip large intermediate sizes for configurations with many experts - if routing_config["num_experts"] >= 512 and intermediate_size > 512: + if routing_config["num_experts"] > 512 and intermediate_size > 512: pytest.skip( f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts" ) From d4182ae91c3baafa2632fba0d7675a4fc4e8cdce Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:58:33 +0000 Subject: [PATCH 12/25] Commit more files for increase supported topK and num experts in deepseek routing for nemotron Signed-off-by: amitz-nv 
<203509407+amitz-nv@users.noreply.github.com> --- .../flashinfer/trtllm/fused_moe/DevKernel.h | 105 ++++++++++-------- .../trtllm/fused_moe/RoutingKernel.h | 5 +- 2 files changed, 60 insertions(+), 50 deletions(-) diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 23abb87a7b..560063c023 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -169,56 +169,65 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, extraFlag, numExperts) \ - if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { 
\ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ +#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, extraFlag, numExperts, \ + numTopExperts) \ + if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, float, float, numExperts, numTopExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == 
tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN( \ + data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN( \ + data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN( \ + data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, \ + numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, 
numBlocks, numThreads, smemSize, \ - stream, extraFlag, numExperts) \ - if (extraFlag) { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, true, numExperts); \ - } else { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, false, numExperts); \ +#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, numExperts, numTopExperts) \ + if (extraFlag) { \ + LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, true, numExperts, numTopExperts); \ + } else { \ + LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, false, numExperts, numTopExperts); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index cae6729368..709fb57c0f 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -176,14 +176,15 @@ struct Data : public DataBase { bool mUseRoutingSoftmax; }; -template +template struct KernelParams : public KernelParamsBase { using InputT = InputT_; using BiasT = BiasT_; using OutputT = OutputT_; static constexpr bool UseGroups = UseGroups_; + static constexpr int MaxNumTopExperts = MaxNumTopExperts_; PackedScoreIdx* mPtrTopKPacked = nullptr; From 8ee2193e8c9933775c51f36237cee1ff8dd79013 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:58:33 +0000 Subject: [PATCH 13/25] Fix formatting Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_routing_deepseek.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index ccb5209e40..422194cf7f 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include +#include #include "flashinfer/exception.h" #include "flashinfer/trtllm/fused_moe/RoutingKernel.cuh" From c899d16517094814987ad9bacfb2fe255c96f43c Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:58:33 +0000 Subject: [PATCH 14/25] Change TODO to comment Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_batched_gemm_runner.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index 9982974953..f3eae5e9e3 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -223,7 +223,7 @@ void TrtllmGenBatchedGemmRunner::run( gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB; gemmData.mInputBuffers.mPtrScaleC = scaleC; gemmData.mInputBuffers.mPtrScaleGate = scaleGateC; - // TODO amitz-nv: Do we want to pass scaleAct instead of using scaleGateC? + // For simplicity, set scaleAct to scaleGateC gemmData.mInputBuffers.mPtrScaleAct = scaleGateC; gemmData.mInputBuffers.mPtrPerTokenSfA = mOptions.transposeMmaOutput ?
perTokensSfB : perTokensSfA; From 0f6f15c5acf009ecd03c3c3ddf0c725409bed623 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 09:58:33 +0000 Subject: [PATCH 15/25] Change default activation_type to Swiglu Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- flashinfer/fused_moe/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 124292d75b..0246ac34ce 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -1420,7 +1420,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, - activation_type: ActivationType = ActivationType.Identity, + activation_type: ActivationType = ActivationType.Swiglu, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ -1531,7 +1531,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( use_routing_scales_on_input: bool, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, - activation_type: int = ActivationType.Identity.value, + activation_type: int = ActivationType.Swiglu.value, ): seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -2188,7 +2188,7 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, - activation_type: int = ActivationType.Identity.value, + activation_type: int = ActivationType.Swiglu.value, ) -> torch.Tensor: """FP8 per tensor scale MoE operation. @@ -2213,7 +2213,7 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. 
(default: 8192) - activation_type (int): Type of activation function (default: 7 - Identity) + activation_type (int): Type of activation function (default: 3 - Swiglu) - 0: Gelu - 1: Relu - 2: Silu From cf6f76b5e8bc943551902b2891e73d20f9b22c52 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 10:00:18 +0000 Subject: [PATCH 16/25] Restore intermediate size factor of 2 for gated activation in getWorkspaceSizeInBytes, getDefaultValidConfigIndex, isValidConfigIndex Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_runner.cu | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 3db24f43c1..fe034dc914 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -302,8 +302,10 @@ size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t int32_t configIndex) const { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - return mRunner.getWorkspaceSizeInBytes(numTokens, intermediateSize, hiddenSize, {}, numTokens, - numExperts, maxNumCtasInBatchDim, configIndex); + int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 
2 : 1); + return mRunner.getWorkspaceSizeInBytes(numTokens, intermediateSizeFactor * intermediateSize, + hiddenSize, {}, numTokens, numExperts, + maxNumCtasInBatchDim, configIndex); } int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, @@ -311,8 +313,10 @@ int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t numTokens) const { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - return mRunner.getDefaultValidConfigIndex(numTokens, intermediateSize, hiddenSize, {}, - numTokens, numExperts, maxNumCtasInBatchDim); + int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); + return mRunner.getDefaultValidConfigIndex(numTokens, intermediateSizeFactor * intermediateSize, + hiddenSize, {}, numTokens, numExperts, + maxNumCtasInBatchDim); } bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, @@ -321,9 +325,10 @@ bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hidde auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 
2 : 1); auto const isValid = - mRunner.isValidConfigIndex(configIndex, numTokens, intermediateSize, hiddenSize, {}, - numTokens, numExperts, maxNumCtasInBatchDim); + mRunner.isValidConfigIndex(configIndex, numTokens, intermediateSizeFactor * intermediateSize, + hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim); return isValid; } From e63e17d13de34646432cab77a72424f617849879 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 11:06:20 +0000 Subject: [PATCH 17/25] Formatting fixes Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_fused_moe_kernel_launcher.cu | 20 +++++++++--------- csrc/trtllm_fused_moe_runner.cu | 27 +++++++++++++----------- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 0d61da7305..9663da2919 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -305,10 +305,9 @@ class FusedMoeLauncher { (int32_t)tile_tokens_dim, this->use_shuffled_weight, this->weight_layout); } else { - moe_runner = std::make_unique(this->mDtypeAct, this->mDtypeWeights, - args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, - this->activation_type, - this->use_shuffled_weight, this->weight_layout); + moe_runner = std::make_unique( + this->mDtypeAct, this->mDtypeWeights, args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, + this->activation_type, this->use_shuffled_weight, this->weight_layout); } if (moe_tactic == -1) { @@ -417,7 +416,8 @@ class Bf16MoeLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { - constexpr ActivationType activation_type = ActivationType::Swiglu; // not exposed in api for now + constexpr ActivationType activation_type = + ActivationType::Swiglu; // not exposed in api for now // Do base 
class init and perform common checks FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, @@ -532,8 +532,8 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, bool use_routing_scales_on_input_param, ActivationType activation_type) { - + int64_t weight_layout, bool use_routing_scales_on_input_param, + ActivationType activation_type) { this->use_routing_scales_on_input = use_routing_scales_on_input_param; auto dtype = hidden_states.dtype(); @@ -968,8 +968,7 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { FusedMoeLauncher::init_common( std::move(args), tile_tokens_dim, routing_method_type, /*use_shuffled_weight=*/true, - static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), - ActivationType::Swiglu); + static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), ActivationType::Swiglu); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -1763,7 +1762,8 @@ Array trtllm_fp4_block_scale_moe( gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, topk_ids, expert_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, - /*weight_layout=*/0, static_cast(act_type), mDtypeAct, mDtypeWeights); + /*weight_layout=*/0, static_cast(act_type), mDtypeAct, + mDtypeWeights); launchers_map[curr_tile_N] = std::move(launcher); } diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index fe034dc914..e3615fa1c4 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -201,7 +201,8 @@ static inline ActType activationTypeToGatedActType(ActivationType actType) { return ActType::GeGlu; default: FLASHINFER_CHECK(false, "Unsupported gated activation type ", - serializeActivationType(actType), " of enum ", 
static_cast(actType)); + serializeActivationType(actType), " of enum ", + static_cast(actType)); } return ActType::SwiGlu; } @@ -214,7 +215,8 @@ static inline EltwiseActType activationTypeToEltwiseActType(ActivationType actTy return EltwiseActType::None; default: FLASHINFER_CHECK(false, "Unsupported eltwise activation type ", - serializeActivationType(actType), " of enum ", static_cast(actType)); + serializeActivationType(actType), " of enum ", + static_cast(actType)); } return EltwiseActType::None; } @@ -224,8 +226,9 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) { int64_t actTypeInt = static_cast(activationType); - FLASHINFER_CHECK(0 <= actTypeInt && actTypeInt < static_cast(ActivationType::InvalidType), - "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt); + FLASHINFER_CHECK( + 0 <= actTypeInt && actTypeInt < static_cast(ActivationType::InvalidType), + "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt); bool isGatedAct = isGatedActivation(activationType); if (isGatedAct) { ActType actType = activationTypeToGatedActType(activationType); @@ -289,12 +292,13 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 
2 : 1); - mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, - expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, - ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, - ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, - ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex, enable_pdl); + mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, + numExperts, maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, + weightsScale, expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, + outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, + outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, + ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device, + configIndex, enable_pdl); } size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -477,8 +481,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace activationData.inDqSfsPtr = workspace.gemm1_output_scale; activationData.outDqSfsPtr = workspace.activation_output_scale; activationData.innerDim = - args.intermediate_size * - (isGatedActivation(args.activation_type) ? 2 : 1); + args.intermediate_size * (isGatedActivation(args.activation_type) ? 
2 : 1); activationData.topK = args.top_k; activationData.numTokens = args.num_tokens; activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx; From 8398e20889448b2151e653f152b114254e09d0df Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:14:11 +0000 Subject: [PATCH 18/25] Treat SwigluBias as gated activation Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- include/flashinfer/trtllm/fused_moe/runner.h | 3 ++- tests/moe/utils.py | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index d8eb59b7be..46617e5dbd 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -174,7 +174,8 @@ inline std::string serializeActivationType(ActivationType activationType) { } inline bool isGatedActivation(ActivationType activationType) { - return activationType == ActivationType::Swiglu || activationType == ActivationType::Geglu; + return activationType == ActivationType::Swiglu || activationType == ActivationType::Geglu || + activationType == ActivationType::SwigluBias; } } // namespace MoE diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 42f5c06418..8ff5cf82a2 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -40,7 +40,11 @@ class QuantMode(IntEnum): def is_gated_activation(activation_type: ActivationType) -> bool: - return activation_type in [ActivationType.Swiglu, ActivationType.Geglu] + return activation_type in [ + ActivationType.Swiglu, + ActivationType.Geglu, + ActivationType.SwigluBias, + ] def skip_checks( From ea67ceffc747f6466718abdbe37cb57c171decaf Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Wed, 28 Jan 2026 21:11:42 +0200 Subject: [PATCH 19/25] Fix use of ActivationType enum in CLI Signed-off-by: amitz-nv 
<203509407+amitz-nv@users.noreply.github.com> --- .../bench_trtllm_gen_fused_moe_autotuner.py | 5 +++-- .../routines/flashinfer_benchmark_utils.py | 17 +++++++++++++++++ benchmarks/routines/moe.py | 5 +++-- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 6a2a9d6b53..ddf1c26188 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -17,6 +17,7 @@ from flashinfer.autotuner import autotune from flashinfer.testing.utils import bench_gpu_time from flashinfer.utils import device_support_pdl +from routines.flashinfer_benchmark_utils import enum_type FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT4_E2M1_MAX = 6.0 @@ -356,11 +357,11 @@ def bench(do_autotune): ) parser.add_argument( "--activation-type", - type=ActivationType, + type=enum_type(ActivationType), choices=list(ActivationType), required=False, default=ActivationType.Swiglu, - help=f"Type of gated activation function: {list(ActivationType)}", + help=f"Type of gated activation function: {[e.name for e in ActivationType]}", ) args = parser.parse_args() if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]: diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index b207f5cb43..9724809510 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -1,3 +1,4 @@ +import argparse import torch from flashinfer.testing.utils import set_seed @@ -453,3 +454,19 @@ def filter_backends_by_compute_capability(backends, routine, device): f"[WARNING] {backend} for routine {routine} is not supported on compute capability {compute_capability}. Skipping." 
) return backends + + +def enum_type(enum_class): + """Generic factory for argparse enum types.""" + + def converter(value): + try: + formatted_value = value[0].upper() + value[1:].lower() + return enum_class[formatted_value] + except KeyError as e: + valid_options = [m.name for m in enum_class] + raise argparse.ArgumentTypeError( + f"Invalid value '{value}'. Must be one of: {', '.join(valid_options)}" + ) from e + + return converter diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index f00260d109..45f8ede7c3 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -24,6 +24,7 @@ from .flashinfer_benchmark_utils import ( dtype_str_to_torch_dtype, + enum_type, get_device, print_perf_metrics, filter_backends_by_compute_capability, @@ -177,11 +178,11 @@ def parse_moe_args(line, parser): ) parser.add_argument( "--activation-type", - type=ActivationType, + type=enum_type(ActivationType), choices=list(ActivationType), required=False, default=ActivationType.Swiglu, - help=f"Type of gated activation function: {list(ActivationType)}", + help=f"Type of gated activation function: {[e.name for e in ActivationType]}", ) parser.add_argument( "--autotune", From abefe2283a9327ca9ae5ca5e6f311f3057461820 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:17:55 +0000 Subject: [PATCH 20/25] Fix activation-type command line argument handling in benchmarks Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py | 4 ++-- benchmarks/routines/flashinfer_benchmark_utils.py | 7 +++---- benchmarks/routines/moe.py | 4 ++-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index ddf1c26188..0171afba2b 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ 
b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -358,10 +358,10 @@ def bench(do_autotune): parser.add_argument( "--activation-type", type=enum_type(ActivationType), - choices=list(ActivationType), + choices=[e.name for e in ActivationType], required=False, default=ActivationType.Swiglu, - help=f"Type of gated activation function: {[e.name for e in ActivationType]}", + help=f"Type of activation function: {[e.name for e in ActivationType]}", ) args = parser.parse_args() if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]: diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 9724809510..375db471e4 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -461,12 +461,11 @@ def enum_type(enum_class): def converter(value): try: - formatted_value = value[0].upper() + value[1:].lower() - return enum_class[formatted_value] + lower_name_to_member = {m.name.lower(): m for m in enum_class} + return lower_name_to_member[value.lower()] except KeyError as e: - valid_options = [m.name for m in enum_class] raise argparse.ArgumentTypeError( - f"Invalid value '{value}'. Must be one of: {', '.join(valid_options)}" + f"Invalid value '{value}'. 
Must be one of: {', '.join([m.name for m in enum_class])}" ) from e return converter diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index 45f8ede7c3..bf117c33c2 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -179,10 +179,10 @@ def parse_moe_args(line, parser): parser.add_argument( "--activation-type", type=enum_type(ActivationType), - choices=list(ActivationType), + choices=[e.name for e in ActivationType], required=False, default=ActivationType.Swiglu, - help=f"Type of gated activation function: {[e.name for e in ActivationType]}", + help=f"Type of activation function: {[e.name for e in ActivationType]}", ) parser.add_argument( "--autotune", From da3576429527d1e86050a218a54e4bef52b6f8f4 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:33:32 +0000 Subject: [PATCH 21/25] Fix choices of activation-type command line argument handling in benchmarks Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py | 2 +- benchmarks/routines/moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 0171afba2b..532805f0bb 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -358,7 +358,7 @@ def bench(do_autotune): parser.add_argument( "--activation-type", type=enum_type(ActivationType), - choices=[e.name for e in ActivationType], + metavar=str([e.name for e in ActivationType]), required=False, default=ActivationType.Swiglu, help=f"Type of activation function: {[e.name for e in ActivationType]}", diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index bf117c33c2..a2e29098ab 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -179,7 +179,7 @@ def parse_moe_args(line, 
parser): parser.add_argument( "--activation-type", type=enum_type(ActivationType), - choices=[e.name for e in ActivationType], + metavar=str([e.name for e in ActivationType]), required=False, default=ActivationType.Swiglu, help=f"Type of activation function: {[e.name for e in ActivationType]}", From 205989f883f6724fc6fee14ab844fa268ba779a9 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 29 Jan 2026 13:45:15 +0000 Subject: [PATCH 22/25] GEMM (non batched) still has mUseShuffledMatrixA member (with 'A' suffix) Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- csrc/trtllm_low_latency_gemm_runner.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/trtllm_low_latency_gemm_runner.cu b/csrc/trtllm_low_latency_gemm_runner.cu index 99639d9687..f3ce0d43c3 100644 --- a/csrc/trtllm_low_latency_gemm_runner.cu +++ b/csrc/trtllm_low_latency_gemm_runner.cu @@ -166,7 +166,7 @@ class TrtllmLowLatencyGemmRunner { configOptions.mDtypeC == mOptions.outputType && configOptions.mTransposeMmaOutput == true && configOptions.mLayoutA == gemm::gemm::MatrixLayout::BlockMajorK && - configOptions.mUseShuffledMatrix) { + configOptions.mUseShuffledMatrixA) { mPassingConfigIndices.push_back(i); } } From e467f1d4ed30e54693815485d19ff4999cbb66e9 Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 29 Jan 2026 13:52:59 +0000 Subject: [PATCH 23/25] Update bench_trtllm_gen_fused_moe_autotuner.py to support more activations Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 532805f0bb..79f15dc95e 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ 
b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -99,9 +99,9 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( ) if is_block_scale: - if activation_type != ActivationType.Swiglu: + if activation_type == ActivationType.Relu2: raise ValueError( - "Only Swiglu activation is supported for FP8 block scale MoE." + "Relu2 activation is not supported for FP8 block scale MoE." ) fn = lambda: trtllm_fp8_block_scale_moe( routing_logits, @@ -182,6 +182,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( top_k: int, warmups: int, iterations: int, + activation_type: ActivationType, ): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) @@ -241,6 +242,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( w13_global_scale = 1.0 / 448.0 / 6.0 w2_global_scale = 1.0 / 448.0 / 6.0 else: + if activation_type == ActivationType.Relu2: + raise ValueError( + "Relu2 activation is supported for FP4 only with 'NvFP4xNvFP4' quant mode" + ) w13, w13_scale = fp4_quantize( w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True ) @@ -295,7 +300,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( RoutingMethodType.Renormalize.value, True, enable_pdl, - ActivationType.Swiglu.value, # act_type + activation_type.value, # act_type None, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, ) @@ -388,4 +393,5 @@ def bench(do_autotune): args.top_k, args.warmups, args.iterations, + args.activation_type, ) From 80d1b5313e529e2c34c06e84461b7647cf61462d Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 29 Jan 2026 14:03:56 +0000 Subject: [PATCH 24/25] Revert activation_type check in bench_trtllm_gen_fused_moe_autotuner.py for trtllm_fp8_block_scale_moe Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- benchmarks/bench_trtllm_gen_fused_moe_autotuner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py
b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 79f15dc95e..8ff7036dec 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -99,9 +99,9 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( ) if is_block_scale: - if activation_type == ActivationType.Relu2: + if activation_type != ActivationType.Swiglu: raise ValueError( - "Relu2 activation is not supported for FP8 block scale MoE." + "Only Swiglu activation is supported for FP8 block scale MoE." ) fn = lambda: trtllm_fp8_block_scale_moe( routing_logits, From 21e0e08a438f3d6144dd10b2983746e606c883ab Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Thu, 29 Jan 2026 14:11:18 +0000 Subject: [PATCH 25/25] Include activation type in results in benchmarks/routines/moe.py Signed-off-by: amitz-nv <203509407+amitz-nv@users.noreply.github.com> --- benchmarks/routines/moe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index a2e29098ab..9ce48f5904 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -895,7 +895,7 @@ def run_fp4_moe( cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = weight_dtype - cur_res["gated_act"] = args.gated_act + cur_res["activation_type"] = args.activation_type.name res.append(cur_res) return res @@ -1762,6 +1762,7 @@ def run_fp8_per_tensor_moe( cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = weight_dtype + cur_res["activation_type"] = args.activation_type.name res.append(cur_res) return res