diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 203faaff82..8ff7036dec 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -4,7 +4,7 @@ import numpy as np from flashinfer import ( RoutingMethodType, - GatedActType, + ActivationType, fp4_quantize, mxfp8_quantize, ) @@ -17,6 +17,7 @@ from flashinfer.autotuner import autotune from flashinfer.testing.utils import bench_gpu_time from flashinfer.utils import device_support_pdl +from routines.flashinfer_benchmark_utils import enum_type FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT4_E2M1_MAX = 6.0 @@ -39,6 +40,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( top_k: int, warmups: int, iterations: int, + activation_type: ActivationType, ): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) @@ -97,6 +99,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( ) if is_block_scale: + if activation_type != ActivationType.Swiglu: + raise ValueError( + "Only Swiglu activation is supported for FP8 block scale MoE." + ) fn = lambda: trtllm_fp8_block_scale_moe( routing_logits, routing_bias, @@ -144,6 +150,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( RoutingMethodType.TopK.value, enable_pdl, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + activation_type.value, ) def bench(do_autotune): @@ -175,6 +182,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( top_k: int, warmups: int, iterations: int, + activation_type: ActivationType, ): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) @@ -234,6 +242,10 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( w13_global_scale = 1.0 / 448.0 / 6.0 w2_global_scale = 1.0 / 448.0 / 6.0 else: + if activation_type == ActivationType.Relu2: + raise ValueError( + "Relu2 activation is supported for FP4 only with 'NvFP4xNvFP4' quant mode" + ) w13, w13_scale = fp4_quantize( w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True ) @@ -288,7 +300,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( RoutingMethodType.Renormalize.value, True, enable_pdl, - GatedActType.SwiGlu.value, # gated_act_type + activation_type.value, # act_type None, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, ) @@ -348,6 +360,14 @@ def bench(do_autotune): parser.add_argument( "--iterations", type=int, default=100, help="Number of benchmark iterations" ) + parser.add_argument( + "--activation-type", + type=enum_type(ActivationType), + metavar=str([e.name for e in ActivationType]), + required=False, + default=ActivationType.Swiglu, + help=f"Type of activation function: {[e.name for e in ActivationType]}", + ) args = parser.parse_args() if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]: bench_trtllm_gen_fused_moe_autotuner_fp8( @@ -360,6 +380,7 @@ def bench(do_autotune): args.top_k, args.warmups, args.iterations, + args.activation_type, ) else: bench_trtllm_gen_fused_moe_autotuner_fp4( @@ -372,4 +393,5 @@ def bench(do_autotune): args.top_k, args.warmups, args.iterations, + args.activation_type, ) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index b207f5cb43..375db471e4 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -1,3 +1,4 @@ +import argparse import torch from flashinfer.testing.utils import set_seed @@ -453,3 +454,18 @@ def filter_backends_by_compute_capability(backends, routine, device): f"[WARNING] {backend} for routine {routine} is not supported on compute capability {compute_capability}. Skipping." ) return backends + + +def enum_type(enum_class): + """Generic factory for argparse enum types.""" + + def converter(value): + try: + lower_name_to_member = {m.name.lower(): m for m in enum_class} + return lower_name_to_member[value.lower()] + except KeyError as e: + raise argparse.ArgumentTypeError( + f"Invalid value '{value}'. Must be one of: {', '.join([m.name for m in enum_class])}" + ) from e + + return converter diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index 2e4dd7bf06..16c221f483 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -5,6 +5,7 @@ import torch import flashinfer +from flashinfer import ActivationType from flashinfer.autotuner import autotune from flashinfer.fused_moe import ( trtllm_fp4_block_scale_moe, @@ -21,6 +22,7 @@ from .flashinfer_benchmark_utils import ( dtype_str_to_torch_dtype, + enum_type, get_device, print_perf_metrics, filter_backends_by_compute_capability, @@ -170,12 +172,12 @@ def parse_moe_args(line, parser): help="Data type of the weights (before quantization).", ) parser.add_argument( - "--gated_act", - type=str, + "--activation-type", + type=enum_type(ActivationType), + metavar=str([e.name for e in ActivationType]), required=False, - default="swiglu", - choices=["swiglu", "geglu"], - help="Type of gated activation function: swiglu | geglu.", + default=ActivationType.Swiglu, + help=f"Type of activation function: {[e.name for e in ActivationType]}", ) parser.add_argument( "--autotune", @@ -242,13 +244,6 @@ def parse_moe_args(line, parser): } args.routing_method_type = routing_method_name_to_type[args.routing_method] - # Normalize gated act type (map string to internal int expected by kernels) - gated_act_name_to_type = { - "swiglu": 0, - "geglu": 1, - } - args.gated_act_type = gated_act_name_to_type[args.gated_act] - if args.verbose >= 1: print(f"[INFO] {args = }") return args @@ -451,7 +446,7 @@ def testTrtllmFp4BlockScaleMoe(args): use_shuffled_weight = args.use_shuffled_weight weight_layout = args.weight_layout is_cuda_graph_compatible = not args.no_cuda_graph - gated_act_type = args.gated_act_type + activation_type = args.activation_type res = [] backends = ["trtllm"] @@ -610,7 +605,7 @@ def run_fp4_moe( local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, routing_method_type=routing_method_type, - gated_act_type=gated_act_type, + activation_type=activation_type.value, do_finalize=True, ) @@ -715,7 +710,7 @@ def run_fp4_moe( cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = weight_dtype - cur_res["gated_act"] = args.gated_act + cur_res["activation_type"] = args.activation_type.name res.append(cur_res) return res @@ -1471,6 +1466,7 @@ def run_fp8_per_tensor_moe( output1_scales_gate_scalar, gemm2_weights_fp8, output2_scales_scalar, + activation_type, ): # Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t] # So we convert None to 0 to indicate "no groups" mode @@ -1493,6 +1489,7 @@ def run_fp8_per_tensor_moe( routed_scaling_factor=routed_scaling_factor, use_routing_scales_on_input=use_routing_scales_on_input, routing_method_type=routing_method_type, + activation_type=activation_type.value, ) # Benchmark timing @@ -1513,6 +1510,7 @@ def run_fp8_per_tensor_moe( output1_scales_gate_scalar, gemm2_weights_fp8, output2_scales_scalar, + args.activation_type, ), ) @@ -1564,6 +1562,7 @@ def run_fp8_per_tensor_moe( cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = weight_dtype + cur_res["activation_type"] = args.activation_type.name res.append(cur_res) return res diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f99e766e86..f3eae5e9e3 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -101,14 +101,16 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( options.mTransposeMmaOutput == mOptions.transposeMmaOutput && (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct && options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch && - tileSize == mOptions.tileSize && - options.mUseShuffledMatrix == mOptions.useShuffledMatrixA && + tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix && options.mLayoutA == mOptions.weightLayout) { if (options.mFusedAct) { if (options.mActType != static_cast(mOptions.actType)) { continue; } } + if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) { + continue; + } if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) { mPassingConfigIndices.push_back(i); @@ -122,6 +124,8 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( << ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB) << ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC) << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8 + << ", mActType: " << (int64_t)mOptions.actType + << ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize; @@ -219,6 +223,8 @@ void TrtllmGenBatchedGemmRunner::run( gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB; gemmData.mInputBuffers.mPtrScaleC = scaleC; gemmData.mInputBuffers.mPtrScaleGate = scaleGateC; + // For simplicity pass set scaleAct to scaleGateC + gemmData.mInputBuffers.mPtrScaleAct = scaleGateC; gemmData.mInputBuffers.mPtrPerTokenSfA = mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA; gemmData.mInputBuffers.mPtrPerTokenSfB = diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index cecf4efb7a..2b0efb7cc9 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -36,7 +36,7 @@ namespace flashinfer { namespace btg = batchedGemm::trtllm::gen; -using tensorrt_llm::kernels::trtllmgen_moe::MoE::GatedActType; +using tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType; using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; @@ -109,7 +109,7 @@ class FusedMoeLauncher { btg::Dtype mDtypeWeights{btg::Dtype::Bfloat16}; btg::Dtype mRoutingBiasDtype{ btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias - GatedActType gated_act_type{GatedActType::SwiGlu}; + ActivationType activation_type{ActivationType::Swiglu}; public: // Constructor that initializes all TensorView members @@ -134,14 +134,14 @@ class FusedMoeLauncher { weight_layout{batchedGemm::gemm::MatrixLayout::MajorK}, mDtypeAct{btg::Dtype::Bfloat16}, mDtypeWeights{btg::Dtype::Bfloat16}, - gated_act_type{GatedActType::SwiGlu} {} + activation_type{ActivationType::Swiglu} {} protected: // Initialize common data necessary for later. // May throw exception from TVM_FFI_ICHECK. void init_common(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, int64_t gated_act_type); + int64_t weight_layout, ActivationType activation_type); // Routing logits [num_tokens, num_experts] void check_routing_logits_shape() const { @@ -305,10 +305,9 @@ class FusedMoeLauncher { (int32_t)tile_tokens_dim, this->use_shuffled_weight, this->weight_layout); } else { - moe_runner = std::make_unique(this->mDtypeAct, this->mDtypeWeights, - args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, - static_cast(this->gated_act_type), - this->use_shuffled_weight, this->weight_layout); + moe_runner = std::make_unique( + this->mDtypeAct, this->mDtypeWeights, args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, + this->activation_type, this->use_shuffled_weight, this->weight_layout); } if (moe_tactic == -1) { @@ -377,7 +376,7 @@ class FusedMoeLauncher { void FusedMoeLauncher::init_common( std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, int64_t gated_act_type) { + int64_t weight_layout, ActivationType activation_type) { // Check devicearchitecture: Blackwell (SM 10.x) required auto device = hidden_states.device().device_id; int major = 0, minor = 0; @@ -400,9 +399,7 @@ void FusedMoeLauncher::init_common( TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) << "the value of weight_layout is not recognized"; this->weight_layout = static_cast(weight_layout); - TVM_FFI_ICHECK(0 <= gated_act_type && gated_act_type <= 1) - << "the value of gated_act_type is not recognized"; - this->gated_act_type = static_cast(gated_act_type); + this->activation_type = activation_type; } class Bf16MoeLauncher : public FusedMoeLauncher { @@ -419,12 +416,12 @@ class Bf16MoeLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { - constexpr int64_t gated_act_type = - static_cast(GatedActType::SwiGlu); // not exposed in api for now + constexpr ActivationType activation_type = + ActivationType::Swiglu; // not exposed in api for now // Do base class init and perform common checks FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + use_shuffled_weight, weight_layout, activation_type); } void check_routing() const override { @@ -489,7 +486,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t gated_act_type, + int64_t num_tokens, int64_t act_type, bool use_shuffled_weight, int64_t weight_layout) { Array> valid_configs; @@ -502,7 +499,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { btg::Dtype::Bfloat16, // dtype_act btg::Dtype::Bfloat16, // dtype_weights false, // useDeepSeekFp8 - tile_N, static_cast(gated_act_type), use_shuffled_weight, + tile_N, static_cast(act_type), use_shuffled_weight, static_cast(weight_layout)); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -535,10 +532,8 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, bool use_routing_scales_on_input_param) { - constexpr int64_t gated_act_type = - static_cast(GatedActType::SwiGlu); // not exposed in api for now - + int64_t weight_layout, bool use_routing_scales_on_input_param, + ActivationType activation_type) { this->use_routing_scales_on_input = use_routing_scales_on_input_param; auto dtype = hidden_states.dtype(); @@ -554,7 +549,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { mDtypeWeights = btg::Dtype::E4m3; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + use_shuffled_weight, weight_layout, activation_type); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -682,7 +677,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { public: static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t gated_act_type, + int64_t num_tokens, int64_t act_type, bool use_shuffled_weight, int64_t weight_layout, btg::Dtype dtype_act, btg::Dtype dtype_weights) { Array> valid_configs; @@ -695,7 +690,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( dtype_act, dtype_weights, false, // useDeepSeekFp8 - tile_N, static_cast(gated_act_type), use_shuffled_weight, + tile_N, static_cast(act_type), use_shuffled_weight, static_cast(weight_layout)); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -732,7 +727,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { - constexpr int64_t gated_act_type = static_cast(GatedActType::SwiGlu); + constexpr ActivationType activation_type = ActivationType::Swiglu; mDtypeAct = btg::Dtype::E4m3; mDtypeWeights = btg::Dtype::E4m3; @@ -752,7 +747,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { args->mDtypeOut = btg::Dtype::Bfloat16; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + use_shuffled_weight, weight_layout, activation_type); } void check_routing() const override { @@ -1049,8 +1044,7 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { FusedMoeLauncher::init_common( std::move(args), tile_tokens_dim, routing_method_type, /*use_shuffled_weight=*/true, - static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), - static_cast(GatedActType::SwiGlu)); + static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), ActivationType::Swiglu); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -1153,8 +1147,8 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( btg::Dtype::Bfloat16, btg::Dtype::MxInt4, false, // useDeepSeekFp8 - tile_N, GatedActType::SwiGlu, - /*useShuffledMatrixA*/ true, batchedGemm::gemm::MatrixLayout::BlockMajorK); + tile_N, ActivationType::Swiglu, + /*useShuffledMatrix*/ true, batchedGemm::gemm::MatrixLayout::BlockMajorK); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, num_tokens); @@ -1208,7 +1202,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, int64_t gated_act_type, btg::Dtype dtype_act, + int64_t weight_layout, ActivationType activation_type, btg::Dtype dtype_act, btg::Dtype dtype_weights) { static const std::tuple device_props = [this] { int major, minor; @@ -1232,7 +1226,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { mDtypeWeights = dtype_weights; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, gated_act_type); + use_shuffled_weight, weight_layout, activation_type); } void check_routing() const override { @@ -1452,7 +1446,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t gated_act_type, + int64_t num_tokens, int64_t act_type, btg::Dtype dtype_act, btg::Dtype dtype_weights) { Array> valid_configs; @@ -1464,8 +1458,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( dtype_act, dtype_weights, false, // useDeepSeekFp8 - tile_N, static_cast(gated_act_type), - /*useShuffledMatrixA*/ true); // FP4 uses shuffled weights + tile_N, static_cast(act_type), + /*useShuffledMatrix*/ true); // FP4 uses shuffled weights auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, num_tokens); @@ -1558,9 +1552,10 @@ Tensor trtllm_fp8_per_tensor_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl, - Array config_index) { + Array config_index, int64_t activation_type) { // Basic type validation auto dtype = hidden_states.dtype(); + auto activation = static_cast(activation_type); if (use_routing_scales_on_input) { TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; } else if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { @@ -1585,7 +1580,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe( auto const hidden_size = hidden_states.size(1); // Use default values that match the original function behavior - bool use_shuffled_weight = true; // Original uses /*useShuffledMatrixA*/ true + bool use_shuffled_weight = true; // Original uses /*useShuffledMatrix*/ true int64_t weight_layout = 0; // Default to MajorK // Calculate supported tile sizes @@ -1617,7 +1612,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, - weight_layout, use_routing_scales_on_input); + weight_layout, use_routing_scales_on_input, activation); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1750,7 +1745,7 @@ Array trtllm_fp4_block_scale_moe( Optional output2_scales_scalar, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t gated_act_type, + int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t act_type, TensorView output, Array config_index) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); @@ -1761,8 +1756,11 @@ Array trtllm_fp4_block_scale_moe( if (hidden_states_scale.has_value()) { hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel(); } + int64_t intermediate_size_factor = + isGatedActivation(static_cast(act_type)) ? 2 : 1; int weight_scale_vec_size = - (local_num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel(); + (local_num_experts * intermediate_size * intermediate_size_factor * hidden_size) / + gemm1_weights_scale.numel(); TVM_FFI_ICHECK(weight_scale_vec_size == 16 || weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size."; @@ -1855,7 +1853,8 @@ Array trtllm_fp4_block_scale_moe( gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, expert_indices, expert_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, - /*weight_layout=*/0, gated_act_type, mDtypeAct, mDtypeWeights); + /*weight_layout=*/0, static_cast(act_type), mDtypeAct, + mDtypeWeights); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1968,7 +1967,7 @@ Array trtllm_mxint4_block_scale_moe( Array> trtllm_get_valid_moe_configs( int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size, - int64_t const num_local_experts, int64_t const gated_act_type, bool const use_shuffled_weight, + int64_t const num_local_experts, int64_t const act_type, bool const use_shuffled_weight, int64_t const weight_layout, int64_t const num_tokens) { auto dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); @@ -1981,7 +1980,7 @@ Array> trtllm_get_valid_moe_configs( if (dtype_act == btg::Dtype::Bfloat16 && dtype_weights == btg::Dtype::Bfloat16) { // BF16 MoE return Bf16MoeLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens, gated_act_type, + num_local_experts, num_tokens, act_type, use_shuffled_weight, weight_layout); } else if (dtype_act == btg::Dtype::E4m3 && dtype_weights == btg::Dtype::E4m3) { @@ -1989,7 +1988,7 @@ Array> trtllm_get_valid_moe_configs( if (!useDeepSeekFp8) { // FP8 per-tensor scale return Fp8PerTensorLauncher::getValidConfigs( - top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, gated_act_type, + top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, act_type, use_shuffled_weight, weight_layout, dtype_act, dtype_weights); } else { // FP8 block scale @@ -2000,7 +1999,7 @@ Array> trtllm_get_valid_moe_configs( } else if (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) { // FP4 block scale return FP4BlockScaleLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens, gated_act_type, + num_local_experts, num_tokens, act_type, dtype_act, dtype_weights); } diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 21faec8ec7..99981e7b63 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include "flashinfer/exception.h" @@ -25,10 +26,14 @@ namespace routingDeepSeek { //////////////////////////////////////////////////////////////////////////////////////////////////// +static constexpr int NumNemotronExperts = 512; static constexpr int NumKimiK2Experts = 384; static constexpr int NumDeepseekExperts = 256; +static constexpr int MaxSupportedExpertCount = + std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); static constexpr int NumTopGroupScores = 2; -static constexpr int MaxNumTopExperts = 8; +static constexpr int DefaultMaxNumTopExperts = 8; +static constexpr int MaxSupportedTopExperts = 22; static constexpr int MaxNumTopGroups = 4; static constexpr int MaxNumGroups = 8; @@ -117,8 +122,8 @@ __global__ void routingMainKernel(KernelParams params) { int32_t topGroupIdx[MaxNumTopGroups]; float expertScoreGroup[MaxNumTopGroups]; int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[MaxNumTopExperts]; // bound of params.mTopK - int32_t topExperts[MaxNumTopExperts]; + float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK + int32_t topExperts[KernelParams::MaxNumTopExperts]; if constexpr (KernelParams::UseGroups) { topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, @@ -154,7 +159,8 @@ __global__ void routingMainKernel(KernelParams params) { // params.mNumExpertsPerGroup // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, // so the access is safe here - expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected + expertScoreGroup[ii] = (ii < params.mNumLimitedGroups) && + (groupIdx < params.mNumExpertGroups) && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat; } @@ -166,7 +172,7 @@ __global__ void routingMainKernel(KernelParams params) { // without groups, each thread just takes `MaxNumTopGroups` experts int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; + int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; if (warpIdx < NumExpertWarps) { @@ -183,13 +189,20 @@ __global__ void routingMainKernel(KernelParams params) { /* minValue */ invalidScoreFloat, params.mTopK); if (laneIdx < params.mTopK) { - smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; - smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + topScores[laneIdx]; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + topExperts[laneIdx]; + } else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) { + smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + invalidScoreFloat; + smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = + MaxSupportedExpertCount - 1; } } __syncthreads(); if (warpIdx == 0) { - int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1; + int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; float intermidiateScore[NumInterTopKPerThread]; int32_t intermidiateExpert[NumInterTopKPerThread]; for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { @@ -270,7 +283,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) cudaGridDependencySynchronize(); } routingPermutation(params, nullptr, warpIdx, clusterBlockRank); } #else @@ -493,6 +506,8 @@ int constexpr getMaxNumExperts(int32_t numExperts) { return NumDeepseekExperts; } else if (numExperts <= NumKimiK2Experts) { return NumKimiK2Experts; + } else if (numExperts <= NumNemotronExperts) { + return NumNemotronExperts; } else { TLLM_LOG_ERROR("Unsupported numExperts"); return 0; @@ -504,13 +519,23 @@ int constexpr getMaxNumExperts(int32_t numExperts) { extraFlag) \ if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, topk::MaxNumExpertsUnit); \ + stream, extraFlag, topk::MaxNumExpertsUnit, \ + DefaultMaxNumTopExperts); \ } else if (data.mNumExperts <= NumDeepseekExperts) { \ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumDeepseekExperts); \ + stream, extraFlag, NumDeepseekExperts, DefaultMaxNumTopExperts); \ } else if (data.mNumExperts <= NumKimiK2Experts) { \ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumKimiK2Experts); \ + stream, extraFlag, NumKimiK2Experts, DefaultMaxNumTopExperts); \ + } else if (data.mNumExperts <= NumNemotronExperts) { \ + if (data.mTopK <= DefaultMaxNumTopExperts) { \ + LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, NumNemotronExperts, \ + DefaultMaxNumTopExperts); \ + } else { \ + LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, NumNemotronExperts, MaxSupportedTopExperts); \ + } \ } else { \ TLLM_LOG_ERROR("Unsupported numExperts"); \ } @@ -532,20 +557,29 @@ void runImpl(Data& data, void* stream) { FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, - "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, - data.mTopK); + // Test limits according to values passed in launch, see definition of LAUNCH_ROUTING_DEEPSEEK + if (data.mNumExperts <= NumKimiK2Experts) { + FLASHINFER_CHECK( + data.mTopK <= DefaultMaxNumTopExperts, + "When NumExperts <= NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d", + DefaultMaxNumTopExperts, data.mTopK); + } else { + FLASHINFER_CHECK( + data.mTopK <= MaxSupportedTopExperts, + "When NumExperts > NumKimiK2Experts, routing kernel expects topK experts <= %d, got %d", + MaxSupportedTopExperts, data.mTopK); + } FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK); FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK, data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts, - "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts, + FLASHINFER_CHECK(data.mTopK <= data.mNumExperts, + "Routing kernel expects topK %d to be at most #experts %d", data.mTopK, data.mNumExperts); - FLASHINFER_CHECK(data.mNumExperts <= NumKimiK2Experts, + FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExpertCount, "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, - NumKimiK2Experts); + MaxSupportedExpertCount); FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, data.mNumExpertGroups); @@ -560,10 +594,6 @@ void runImpl(Data& data, void* stream) { data.mNumExperts / data.mNumExpertGroups <= WarpSize, "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); - } else { - FLASHINFER_CHECK(data.mTopK <= topk::MaxNumTopK, - "Routing kernel expects top K %d to be <= #warps %d", data.mTopK, - topk::MaxNumTopK); } FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); @@ -598,7 +628,7 @@ void runImpl(Data& data, void* stream) { int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; if (data.mPtrTopKIds == nullptr) { int const numThreadsMain = - data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts; + max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); LAUNCH_ROUTING_DEEPSEEK(data, /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, /*smemSize=*/0, // No dynamic smem diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index b5ff5757c9..e3615fa1c4 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -60,7 +60,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) { if (routingMethodType == RoutingMethodType::DeepSeekV3) { - FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8"); + FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); moe::dev::routing::routingDeepSeek::Data routingData; routingData.mDtypeExpW = @@ -189,13 +189,49 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 namespace PermuteGemm1 { +using tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType; +using tensorrt_llm::kernels::trtllmgen_moe::MoE::isGatedActivation; +using tensorrt_llm::kernels::trtllmgen_moe::MoE::serializeActivationType; + +static inline ActType activationTypeToGatedActType(ActivationType actType) { + switch (actType) { + case ActivationType::Swiglu: + return ActType::SwiGlu; + case ActivationType::Geglu: + return ActType::GeGlu; + default: + FLASHINFER_CHECK(false, "Unsupported gated activation type ", + serializeActivationType(actType), " of enum ", + static_cast(actType)); + } + return ActType::SwiGlu; +} + +static inline EltwiseActType activationTypeToEltwiseActType(ActivationType actType) { + switch (actType) { + case ActivationType::Relu2: + return EltwiseActType::Relu2; + case ActivationType::Identity: + return EltwiseActType::None; + default: + FLASHINFER_CHECK(false, "Unsupported eltwise activation type ", + serializeActivationType(actType), " of enum ", + static_cast(actType)); + } + return EltwiseActType::None; +} + tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( btg::Dtype dtypeAct, btg::Dtype dtypeWeights, int32_t tileTokensDim, bool useDeepSeekFp8, - MoE::GatedActType gatedActType, bool useShuffledMatrixA, + ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) { - if (gatedActType == MoE::GatedActType::SwiGlu || gatedActType == MoE::GatedActType::GeGlu) { - ActType actType = - (gatedActType == MoE::GatedActType::SwiGlu) ? ActType::SwiGlu : ActType::GeGlu; + int64_t actTypeInt = static_cast(activationType); + FLASHINFER_CHECK( + 0 <= actTypeInt && actTypeInt < static_cast(ActivationType::InvalidType), + "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt); + bool isGatedAct = isGatedActivation(activationType); + if (isGatedAct) { + ActType actType = activationTypeToGatedActType(activationType); tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { // Swap A and B dtypes because transposeMmaOutput is hardcoded to true .dtypeA = dtypeWeights, @@ -209,24 +245,40 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 64 : 128, - .useShuffledMatrixA = useShuffledMatrixA, + .useShuffledMatrix = useShuffledMatrix, .weightLayout = weightLayout}; return options; } else { - FLASHINFER_CHECK(false, "Unimplemented gated act type ", - MoE::serializeGatedActType(gatedActType), " of enum ", (int)gatedActType); + EltwiseActType actType = activationTypeToEltwiseActType(activationType); + tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { + // Swap A and B dtypes because transposeMmaOutput is hardcoded to true + .dtypeA = dtypeWeights, + .dtypeB = dtypeAct, + .dtypeC = dtypeAct, + .eltwiseActType = actType, + .deepSeekFp8 = useDeepSeekFp8, + .fusedAct = false, + .routeAct = true, + .staticBatch = false, + .transposeMmaOutput = true, + .tileSize = tileTokensDim, + .epilogueTileM = 128, + .useShuffledMatrix = useShuffledMatrix, + .weightLayout = weightLayout}; + return options; } } Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim, - MoE::GatedActType gatedActType, bool useShuffledMatrixA, + ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) : mDtypeAct(dtypeAct), mDtypeWeights(dtypeWeights), mTileTokensDim(tileTokensDim), mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( - getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, gatedActType, - useShuffledMatrixA, weightLayout))) {} + getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, activationType, + useShuffledMatrix, weightLayout))), + mActType(activationType) {} void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* weightsScale, void* expertWeights, float* outputScalesScalar, float* outputScalesGateScalar, @@ -239,12 +291,14 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl) { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - mRunner.run(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, - expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, - ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, - ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, - ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex, enable_pdl); + int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); + mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, + numExperts, maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, + weightsScale, expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, + outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, + outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, + ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device, + configIndex, enable_pdl); } size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -252,8 +306,10 @@ size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t int32_t configIndex) const { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - return mRunner.getWorkspaceSizeInBytes(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, - numExperts, maxNumCtasInBatchDim, configIndex); + int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); + return mRunner.getWorkspaceSizeInBytes(numTokens, intermediateSizeFactor * intermediateSize, + hiddenSize, {}, numTokens, numExperts, + maxNumCtasInBatchDim, configIndex); } int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, @@ -261,8 +317,10 @@ int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t numTokens) const { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - return mRunner.getDefaultValidConfigIndex(numTokens, 2 * intermediateSize, hiddenSize, {}, - numTokens, numExperts, maxNumCtasInBatchDim); + int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); + return mRunner.getDefaultValidConfigIndex(numTokens, intermediateSizeFactor * intermediateSize, + hiddenSize, {}, numTokens, numExperts, + maxNumCtasInBatchDim); } bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, @@ -271,9 +329,10 @@ bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hidde auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); + int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); auto const isValid = - mRunner.isValidConfigIndex(configIndex, numTokens, 2 * intermediateSize, hiddenSize, {}, - numTokens, numExperts, maxNumCtasInBatchDim); + mRunner.isValidConfigIndex(configIndex, numTokens, intermediateSizeFactor * intermediateSize, + hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim); return isValid; } @@ -286,12 +345,13 @@ std::vector Runner::getPassingConfigIndices() const { namespace Gemm2 { tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, int32_t tileTokensDim, - bool useDeepSeekFp8, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) { + bool useDeepSeekFp8, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) { tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { // Swap A and B dtypes because transposeMmaOutput is hardcoded to true .dtypeA = dtypeWeights, .dtypeB = dtypeAct, .dtypeC = dtypeOut, + .eltwiseActType = EltwiseActType::None, .deepSeekFp8 = useDeepSeekFp8, .fusedAct = false, .routeAct = false, @@ -299,13 +359,13 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 64 : 128, - .useShuffledMatrixA = useShuffledMatrixA, + .useShuffledMatrix = useShuffledMatrix, .weightLayout = weightLayout}; return options; } Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, - bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrixA, + bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) : mDtypeAct(dtypeAct), mDtypeWeights(dtypeWeights), @@ -313,7 +373,7 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut mTileTokensDim(tileTokensDim), mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( getOptions(dtypeAct, dtypeWeights, dtypeOut, tileTokensDim, useDeepSeekFp8, - useShuffledMatrixA, weightLayout))) {} + useShuffledMatrix, weightLayout))) {} void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void* weights, void* weightsScale, float* outputScalesScalar, float* ptrBias, void* output, @@ -373,12 +433,12 @@ std::vector Runner::getPassingConfigIndices() const { namespace MoE { Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, - int32_t tileTokensDim, GatedActType gatedActType, bool useShuffledMatrixA, + int32_t tileTokensDim, ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) : mPermuteGemm1(PermuteGemm1::Runner(dtypeAct, dtypeWeights, useDeepSeekFp8, tileTokensDim, - gatedActType, useShuffledMatrixA, weightLayout)), + activationType, useShuffledMatrix, weightLayout)), mGemm2(Gemm2::Runner(dtypeAct, dtypeWeights, btg::Dtype::Bfloat16, useDeepSeekFp8, - tileTokensDim, useShuffledMatrixA, weightLayout)) { + tileTokensDim, useShuffledMatrix, weightLayout)) { auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices(); auto const& gemm2PassingIndices = mGemm2.getPassingConfigIndices(); @@ -395,9 +455,9 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8 } Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim, - bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) - : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, GatedActType::SwiGlu, - useShuffledMatrixA, weightLayout) {} + bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) + : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, ActivationType::Swiglu, + useShuffledMatrix, weightLayout) {} void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace, moe::dev::convertsf::Data& convertSfData, @@ -420,7 +480,8 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace activationData.outPtr = workspace.activation_output; activationData.inDqSfsPtr = workspace.gemm1_output_scale; activationData.outDqSfsPtr = workspace.activation_output_scale; - activationData.innerDim = args.intermediate_size * 2; + activationData.innerDim = + args.intermediate_size * (isGatedActivation(args.activation_type) ? 2 : 1); activationData.topK = args.top_k; activationData.numTokens = args.num_tokens; activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx; diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index c22b4a0a55..c78ceb215b 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -75,8 +75,8 @@ ) from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize from .fused_moe import ( + ActivationType, RoutingMethodType, - GatedActType, cutlass_fused_moe, reorder_rows_for_gated_act_gemm, trtllm_fp4_block_scale_moe, diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index f7886fe400..a077ea82d5 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -15,8 +15,8 @@ """ from .core import ( + ActivationType, RoutingMethodType, - GatedActType, WeightLayout, convert_to_block_layout, cutlass_fused_moe, @@ -40,8 +40,8 @@ ) __all__ = [ + "ActivationType", "RoutingMethodType", - "GatedActType", "WeightLayout", "convert_to_block_layout", "cutlass_fused_moe", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 0e9d643b4c..2821ce829a 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -173,15 +173,6 @@ class WeightLayout(IntEnum): BlockMajorK = 2 -# The type of gated activation function -# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h -class GatedActType(IntEnum): - # SwiGlu - SwiGlu = 0 - # GeGlu - GeGlu = 1 - - @functools.cache def is_trtllm_moe_supported( dtype_weights: DtypeTrtllmGen, @@ -221,12 +212,16 @@ def _maybe_get_cached_w3_w1_permute_indices( dst_w3_w1_weight: torch.Tensor, epilogue_tile_m: int, num_elts_per_sf: Union[None, int] = None, + is_gated_act_gemm: bool = True, ) -> torch.Tensor: # Create a unique cache key (weight_type, weight_shape) cache_key = ("w3_w1", dst_w3_w1_weight.shape) if cache_key not in _cache_permute_indices: # Get permute indices and chain them together - permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight) + if is_gated_act_gemm: + permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight) + else: + permute0 = torch.arange(dst_w3_w1_weight.shape[0], dtype=torch.long) if num_elts_per_sf is None: permute1 = get_shuffle_matrix_a_row_indices( dst_w3_w1_weight, epilogue_tile_m=epilogue_tile_m @@ -994,7 +989,7 @@ def __init__( use_deepseek_fp8: bool, hidden_size: int, intermediate_size: int, - gated_act_type: int = GatedActType.SwiGlu, + activation_type: int = ActivationType.Swiglu, use_shuffled_weight: bool = False, weight_layout: int = WeightLayout.MajorK, use_packed_weights: bool = False, @@ -1007,7 +1002,7 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gated_act_type = GatedActType(gated_act_type) + self.activation_type = ActivationType(activation_type) self.use_shuffled_weight = use_shuffled_weight self.weight_layout = WeightLayout(weight_layout) self.use_packed_weights = use_packed_weights @@ -1035,7 +1030,7 @@ def get_valid_tactics( self.hidden_size, self.intermediate_size, self.num_local_experts, - self.gated_act_type, + self.activation_type, self.use_shuffled_weight, self.weight_layout, num_tokens, @@ -1179,6 +1174,7 @@ def forward( kwargs["routing_method_type"], kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, + self.activation_type, ) elif ( self.dtype_act == DtypeTrtllmGen.Bfloat16 @@ -1239,7 +1235,7 @@ def forward( kwargs["routing_method_type"], kwargs["enable_pdl"], kwargs["do_finalize"], - self.gated_act_type, + self.activation_type, output, [-1, -1] if tactic == -1 else tactic, ) @@ -1328,7 +1324,7 @@ def trtllm_bf16_moe_op( intermediate_size=intermediate_size, weight_layout=weight_layout, use_shuffled_weight=use_shuffled_weight, - gated_act_type=GatedActType.SwiGlu, # Default for BF16 + activation_type=ActivationType.Swiglu, # Default for BF16 ) inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] @@ -1426,6 +1422,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + activation_type: ActivationType = ActivationType.Swiglu, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ -1460,6 +1457,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( intermediate_size=intermediate_size, weight_layout=WeightLayout.MajorK, use_shuffled_weight=True, + activation_type=activation_type, ) inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] @@ -1484,6 +1482,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( use_routing_scales_on_input=use_routing_scales_on_input, routing_method_type=routing_method_type, enable_pdl=enable_pdl, + activation_type=activation_type.value, ) # Call the C++ function result = moe_op.trtllm_fp8_per_tensor_scale_moe( @@ -1508,6 +1507,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type, enable_pdl, [-1, -1] if tactic == -1 else tactic, + activation_type.value, ) return result @@ -1533,6 +1533,7 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( use_routing_scales_on_input: bool, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, + activation_type: int = ActivationType.Swiglu.value, ): seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -1751,7 +1752,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type: int, do_finalize: bool, enable_pdl: Optional[bool] = None, - gated_act_type: int = 0, + activation_type: int = ActivationType.Swiglu.value, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -1811,7 +1812,7 @@ def trtllm_fp4_block_scale_moe_op( use_deepseek_fp8=False, hidden_size=hidden_size, intermediate_size=intermediate_size, - gated_act_type=gated_act_type, + activation_type=activation_type, weight_layout=WeightLayout.MajorK, use_shuffled_weight=True, ) @@ -1858,7 +1859,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type=routing_method_type, enable_pdl=enable_pdl, do_finalize=do_finalize, - gated_act_type=gated_act_type, + activation_type=activation_type, ) # Call the C++ function for block scale MoE @@ -1892,7 +1893,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type, do_finalize, enable_pdl, - gated_act_type, + activation_type, output, [-1, -1] if tactic == -1 else tactic, ) @@ -1937,7 +1938,7 @@ def _fake_trtllm_fp4_block_scale_moe( routing_method_type: int, do_finalize: bool, enable_pdl: bool, - gated_act_type: int, + activation_type: int, output: Optional[torch.Tensor], tune_max_num_tokens: int, ): @@ -2009,7 +2010,7 @@ def trtllm_mxint4_block_scale_moe_op( use_deepseek_fp8=False, hidden_size=hidden_size, intermediate_size=intermediate_size, - gated_act_type=GatedActType.SwiGlu, + activation_type=ActivationType.Swiglu, weight_layout=WeightLayout.BlockMajorK, use_shuffled_weight=True, ) @@ -2216,6 +2217,7 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + activation_type: int = ActivationType.Swiglu.value, ) -> torch.Tensor: """FP8 per tensor scale MoE operation. @@ -2240,6 +2242,15 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) + activation_type (int): Type of activation function (default: 3 - Swiglu) + - 0: Gelu + - 1: Relu + - 2: Silu + - 3: Swiglu + - 4: Geglu + - 5: SwigluBias + - 6: Relu2 + - 7: Identity Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] @@ -2265,6 +2276,7 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type, enable_pdl, tune_max_num_tokens, + activation_type, ) @@ -2466,7 +2478,7 @@ def trtllm_fp4_block_scale_moe( routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, - gated_act_type: int = 0, + activation_type: int = ActivationType.Swiglu.value, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -2521,9 +2533,15 @@ def trtllm_fp4_block_scale_moe( - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. - gated_act_type (int): Type of gated activation function (default: 0) - - 0: SwiGlu - - 1: GeGlu + activation_type (int): Type of activation function (default: 3 - Swiglu) + - 0: Gelu + - 1: Relu + - 2: Silu + - 3: Swiglu + - 4: Geglu + - 5: SwigluBias + - 6: Relu2 + - 7: Identity tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. @@ -2561,7 +2579,7 @@ def trtllm_fp4_block_scale_moe( routing_method_type, do_finalize, enable_pdl, - gated_act_type, + activation_type, output, tune_max_num_tokens, ) @@ -2596,7 +2614,7 @@ def trtllm_fp4_block_scale_routed_moe( routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, - gated_act_type: int = 0, + activation_type: int = ActivationType.Swiglu.value, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -2652,9 +2670,15 @@ def trtllm_fp4_block_scale_routed_moe( - 3: Llama4 (Top1 -> Sigmoid) - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) - gated_act_type (int): Type of gated activation function (default: 0) - - 0: SwiGlu - - 1: GeGlu + activation_type (int): Type of activation function (default: 3 - Swiglu) + - 0: Gelu + - 1: Relu + - 2: Silu + - 3: Swiglu + - 4: Geglu + - 5: SwigluBias + - 6: Relu2 + - 7: Identity tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. @@ -2693,7 +2717,7 @@ def trtllm_fp4_block_scale_routed_moe( routing_method_type, do_finalize, enable_pdl, - gated_act_type, + activation_type, output, tune_max_num_tokens, ) diff --git a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h index 970f1ae494..54cd824c0e 100644 --- a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h +++ b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h @@ -47,11 +47,27 @@ enum class ActType { GeGlu, }; +// Type of the element-wise activation to apply after the Gemm +enum class EltwiseActType { + None = 0, + // Gelu is defined as the following operation: + // act = x0 * phi(x0) + // where x0 is the output of the Gemm + // phi is the CDF of standard normal distribution approximated by + // phi(x) = 0.5 * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x))) + Gelu, + // Relu2 (also known as squared Relu) is defined as the following operation: + // act = relu(x0) ^ 2 + // where x0 is the output of the Gemm. + Relu2, +}; + struct TrtllmGenBatchedGemmRunnerOptions { batchedGemm::trtllm::gen::Dtype dtypeA; batchedGemm::trtllm::gen::Dtype dtypeB; batchedGemm::trtllm::gen::Dtype dtypeC; ActType actType{ActType::SwiGlu}; + EltwiseActType eltwiseActType{EltwiseActType::None}; bool deepSeekFp8{false}; bool fusedAct{false}; bool routeAct{false}; @@ -59,7 +75,7 @@ struct TrtllmGenBatchedGemmRunnerOptions { bool transposeMmaOutput{false}; int32_t tileSize{8}; int32_t epilogueTileM{128}; - bool useShuffledMatrixA{false}; + bool useShuffledMatrix{false}; batchedGemm::gemm::MatrixLayout weightLayout{batchedGemm::gemm::MatrixLayout::MajorK}; }; diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 23abb87a7b..560063c023 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -169,56 +169,65 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, extraFlag, numExperts) \ - if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ +#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, extraFlag, numExperts, \ + numTopExperts) \ + if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, float, float, numExperts, numTopExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN( \ + data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN( \ + data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN( \ + data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, \ + numTopExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, numExperts) \ - if (extraFlag) { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, true, numExperts); \ - } else { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, false, numExperts); \ +#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, numExperts, numTopExperts) \ + if (extraFlag) { \ + LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, true, numExperts, numTopExperts); \ + } else { \ + LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, false, numExperts, numTopExperts); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index cae6729368..709fb57c0f 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -176,14 +176,15 @@ struct Data : public DataBase { bool mUseRoutingSoftmax; }; -template +template struct KernelParams : public KernelParamsBase { using InputT = InputT_; using BiasT = BiasT_; using OutputT = OutputT_; static constexpr bool UseGroups = UseGroups_; + static constexpr int MaxNumTopExperts = MaxNumTopExperts_; PackedScoreIdx* mPtrTopKPacked = nullptr; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 3941a23249..46617e5dbd 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -136,25 +136,48 @@ class Runner { } // namespace Routing namespace MoE { -// The type of gated activation function +// The type of activation function // Please keep this in sync with the counterpart defined in flashinfer/flashinfer/fused_moe/core.py -enum class GatedActType : int64_t { - // SwiGlu - SwiGlu = 0, - // GeGlu - GeGlu = 1, +enum class ActivationType : int64_t { + Gelu = 0, + Relu = 1, + Silu = 2, + Swiglu = 3, + Geglu = 4, + SwigluBias = 5, + Relu2 = 6, + Identity = 7, + InvalidType = 8, // Must be last }; -inline std::string serializeGatedActType(GatedActType gatedActType) { - switch (gatedActType) { - case GatedActType::SwiGlu: - return "SwiGlu"; - case GatedActType::GeGlu: - return "GeGlu"; +inline std::string serializeActivationType(ActivationType activationType) { + switch (activationType) { + case ActivationType::Gelu: + return "Gelu"; + case ActivationType::Relu: + return "Relu"; + case ActivationType::Silu: + return "Silu"; + case ActivationType::Swiglu: + return "Swiglu"; + case ActivationType::Geglu: + return "Geglu"; + case ActivationType::SwigluBias: + return "SwigluBias"; + case ActivationType::Relu2: + return "Relu2"; + case ActivationType::Identity: + return "Identity"; default: - return "InvalidGatedActType"; // TODO throw error + return "InvalidActivationType"; // TODO throw error }; } + +inline bool isGatedActivation(ActivationType activationType) { + return activationType == ActivationType::Swiglu || activationType == ActivationType::Geglu || + activationType == ActivationType::SwigluBias; +} + } // namespace MoE namespace PermuteGemm1 { @@ -162,7 +185,7 @@ class Runner { public: explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8, - int tileTokensDim, MoE::GatedActType gatedActType, bool useShuffledMatrixA, + int tileTokensDim, MoE::ActivationType activationType, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weight_layout); size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -193,6 +216,7 @@ class Runner { batchedGemm::trtllm::gen::Dtype mDtypeWeights; int32_t mTileTokensDim; tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner mRunner; + tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType mActType; }; } // namespace PermuteGemm1 @@ -202,7 +226,7 @@ class Runner { explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, batchedGemm::trtllm::gen::Dtype outputDtype, bool useDeepSeekFp8, - int tileTokensDim, bool useShuffledMatrixA, + int tileTokensDim, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weight_layout); size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -259,6 +283,8 @@ struct MoERunnerArgs { float* gemm1_clamp_limit = nullptr; float* gemm2_bias = nullptr; + ActivationType activation_type = ActivationType::Swiglu; + int32_t num_tokens{0}; int32_t num_experts{0}; // Hidden dimension input of MoE block. It might be padded. @@ -356,10 +382,10 @@ class Runner { // FIXME: tileTokensDim is hardcoded for now Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim = 8, - GatedActType gatedActType = GatedActType::SwiGlu, bool useShuffledMatrixA = false, + ActivationType activationType = ActivationType::Swiglu, bool useShuffledMatrix = false, batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK); Runner(batchedGemm::trtllm::gen::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim = 8, - bool useShuffledMatrixA = false, + bool useShuffledMatrix = false, batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK); void run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py index 711e05f234..cd44f2faf2 100644 --- a/tests/moe/test_dpsk_fused_moe_fp8.py +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -8,7 +8,7 @@ trtllm_fp8_block_scale_moe, ) from .utils import skip_checks, QuantMode -from flashinfer import GatedActType +from flashinfer import ActivationType def dequant_fp8_block_scaled( @@ -616,7 +616,7 @@ def __init__(self): moe_impl=moe_impl, routing_config=routing_config, weight_processing=weight_processing, - gated_act_type=GatedActType.SwiGlu, + activation_type=ActivationType.Swiglu, num_tokens=seq_len, hidden_size=7168, # DeepSeek-V3 hidden size intermediate_size=intermediate_size, diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 89cbf84d4e..a93767e457 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -22,8 +22,8 @@ from torch.nn import functional as F from flashinfer import ( + ActivationType, RoutingMethodType, - GatedActType, e2m1_and_ufp8sf_scale_to_float, fp4_quantize, mxfp8_dequantize_host, @@ -46,7 +46,7 @@ get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) -from .utils import skip_checks, QuantMode +from .utils import is_gated_activation, skip_checks, QuantMode # Max num tokens to tune for trtllm-gen fused moe @@ -209,7 +209,7 @@ def _run_moe_computation(self, runtime_args): local_num_experts=self.config["num_experts"], routed_scaling_factor=self.config["routed_scaling"], routing_method_type=self.config["routing_method_type"], - gated_act_type=self.config["gated_act_type"], + activation_type=self.config["activation_type"], do_finalize=True, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) @@ -227,6 +227,12 @@ class Moe(ABC): def __init__(self): self.name = self.__class__.__name__ + @property + @abstractmethod + def quant_mode(self) -> QuantMode: + """Get the quantization mode of this MoE implementation.""" + pass + @abstractmethod def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize static weights and compute global scale factors (done offline).""" @@ -305,13 +311,17 @@ class FP4Moe(Moe): def __init__(self, quant_mode: QuantMode): super().__init__() - self.quant_mode = quant_mode + self._quant_mode = quant_mode self.is_mxfp4 = ( quant_mode == QuantMode.FP4_MXFP4_MXFP8 or quant_mode == QuantMode.FP4_MXFP4_Bf16 ) self.sf_vec_size = 32 if self.is_mxfp4 else 16 + @property + def quant_mode(self) -> QuantMode: + return self._quant_mode + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP4 format and compute global scale factors.""" num_experts = gemm1_weights.shape[0] @@ -408,13 +418,16 @@ def prepare_static_weights_for_kernel( ) # Convert quantized weights to proper formats + intermediate_size_factor = 2 if is_gated_activation(args.activation_type) else 1 gemm1_weights_fp4 = args.gemm1_weights.view(torch.float8_e4m3fn).reshape( - num_experts, 2 * intermediate_size, hidden_size // 2 + num_experts, intermediate_size_factor * intermediate_size, hidden_size // 2 ) # packed fp4 gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( torch.float8_e4m3fn ).reshape( - num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size + num_experts, + intermediate_size_factor * intermediate_size, + hidden_size // self.sf_vec_size, ) # fp8 scaling factors gemm2_weights_fp4 = args.gemm2_weights.view(torch.float8_e4m3fn).reshape( @@ -440,6 +453,7 @@ def prepare_static_weights_for_kernel( self._cache_permute_indices, gemm1_weights_fp4[i].view(torch.uint8), epilogue_tile_m, + is_gated_act_gemm=is_gated_activation(args.activation_type), ) gemm1_weights_fp4_shuffled.append( gemm1_weights_fp4[i] @@ -452,6 +466,7 @@ def prepare_static_weights_for_kernel( gemm1_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m, num_elts_per_sf=16, + is_gated_act_gemm=is_gated_activation(args.activation_type), ) gemm1_scales_fp4_shuffled.append( block_scale_interleave( @@ -496,7 +511,9 @@ def prepare_static_weights_for_kernel( torch.stack(gemm1_scales_fp4_shuffled) .view(torch.float8_e4m3fn) .reshape( - num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size + num_experts, + intermediate_size_factor * intermediate_size, + hidden_size // self.sf_vec_size, ) ) @@ -508,11 +525,16 @@ def prepare_static_weights_for_kernel( ) # Calculate scaling factors that depend on weights - scale_c_fc1 = ( - args_dequant.c_global_sf - * (1.0 / args.gemm1_scales_global) - * (1.0 / args.hidden_states_scale_global) - ) + if is_gated_activation(args.activation_type): + scale_c_fc1 = ( + args_dequant.c_global_sf + * (1.0 / args.gemm1_scales_global) + * (1.0 / args.hidden_states_scale_global) + ) + else: + scale_c_fc1 = torch.full_like( + args.gemm1_scales_global, args_dequant.c_global_sf + ) scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( 1.0 / args.hidden_states_scale_global ) @@ -543,7 +565,7 @@ def call_moe( top_k_groups = kwargs["top_k_groups"] intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] - gated_act_type = kwargs["gated_act_type"] + activation_type = kwargs["activation_type"] routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) @@ -556,7 +578,7 @@ def call_moe( "top_k_groups": top_k_groups, "intermediate_size": intermediate_size, "routed_scaling": routed_scaling, - "gated_act_type": gated_act_type, + "activation_type": activation_type, "routing_method_type": routing_method_type, "enable_autotune": enable_autotune, } @@ -610,6 +632,10 @@ def mxint4_quantize( class MxInt4BlockScaleMoe(Moe): """MxInt4 MoE implementation with block scaling (DeepSeek style).""" + @property + def quant_mode(self) -> QuantMode: + return QuantMode.MXINT4_BF16_BF16 + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to MxInt4 with block scaling.""" num_experts = gemm1_weights.shape[0] @@ -804,6 +830,10 @@ def get_tolerances(self): class FP8BlockScaleMoe(Moe): """FP8 MoE implementation with block scaling (DeepSeek style).""" + @property + def quant_mode(self) -> QuantMode: + return QuantMode.FP8_BLOCK_SCALE + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 with block scaling.""" num_experts = gemm1_weights.shape[0] @@ -1025,6 +1055,10 @@ def get_tolerances(self): class FP8PerTensorMoe(Moe): """FP8 MoE implementation with per-tensor scaling (Llama4 style).""" + @property + def quant_mode(self) -> QuantMode: + return QuantMode.FP8_PER_TENSOR + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 per-tensor and compute global scale factors.""" # Compute global scale factor for hidden states (offline calibration) @@ -1080,14 +1114,20 @@ def prepare_static_weights_for_kernel( # Reorder rows of W1 for fused gated activation gemm1_weights_fp8_interleaved = [] for i in range(num_experts): - gemm1_weights_fp8_interleaved.append( - reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) - ) + if is_gated_activation(args.activation_type): + weights = reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) + else: + weights = args.gemm1_weights[i].clone() + gemm1_weights_fp8_interleaved.append(weights) # Stack weights and scales for all experts gemm1_weights_fp8_interleaved = torch.stack( gemm1_weights_fp8_interleaved - ).reshape(num_experts, 2 * intermediate_size, hidden_size) + ).reshape( + num_experts, + (2 if is_gated_activation(args.activation_type) else 1) * intermediate_size, + hidden_size, + ) # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp8_shuffled = [] @@ -1114,11 +1154,16 @@ def prepare_static_weights_for_kernel( ) # Calculate scaling factors that depend on weights - scale_c_fc1 = ( - args_dequant.c_global_sf - * (1.0 / args.gemm1_scales_global) - * (1.0 / args.hidden_states_scale_global) - ) + if is_gated_activation(args.activation_type): + scale_c_fc1 = ( + args_dequant.c_global_sf + * (1.0 / args.gemm1_scales_global) + * (1.0 / args.hidden_states_scale_global) + ) + else: + scale_c_fc1 = torch.full_like( + args.gemm1_scales_global, args_dequant.c_global_sf + ) scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( 1.0 / args.hidden_states_scale_global ) @@ -1148,6 +1193,7 @@ def call_moe( routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) + activation_type = kwargs["activation_type"] # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_fp8, _ = quant_fp8_per_tensor( @@ -1181,6 +1227,7 @@ def call_moe( == RoutingMethodType.Llama4, # Use_routing_scales_on_input routing_method_type, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + activation_type=activation_type, ) return output.to(torch.float) @@ -1202,6 +1249,10 @@ def get_tolerances(self): class BF16Moe(Moe): """BF16 MoE implementation.""" + @property + def quant_mode(self) -> QuantMode: + return QuantMode.BF16 + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """No scaling for weights.""" return { @@ -1383,7 +1434,7 @@ def __init__( gemm2_scales_global, permute_info, use_routing_scales_on_input, - gated_act_type, + activation_type, ): self.num_tokens = num_tokens self.num_experts = num_experts @@ -1403,7 +1454,7 @@ def __init__( self.gemm2_scales_global = gemm2_scales_global self.permute_info = permute_info self.use_routing_scales_on_input = use_routing_scales_on_input - self.gated_act_type = gated_act_type + self.activation_type = activation_type class moe_args_dequant: @@ -1423,7 +1474,7 @@ def __init__( gemm2_weights, permute_info, use_routing_scales_on_input, - gated_act_type, + activation_type, hidden_states_scale=None, ): self.num_tokens = num_tokens @@ -1438,7 +1489,7 @@ def __init__( self.gemm2_weights = gemm2_weights self.permute_info = permute_info self.use_routing_scales_on_input = use_routing_scales_on_input - self.gated_act_type = gated_act_type + self.activation_type = activation_type self.hidden_states_scale = hidden_states_scale @@ -1862,7 +1913,11 @@ def run_moe_dequant(args, quant_mode: QuantMode): # Gemm1 gemm1_output = torch.full( - (total_num_padded_tokens, 2 * args.intermediate_size), + ( + total_num_padded_tokens, + (2 if is_gated_activation(args.activation_type) else 1) + * args.intermediate_size, + ), float("nan"), device="cuda", ).to(torch.float) @@ -1897,12 +1952,13 @@ def run_moe_dequant(args, quant_mode: QuantMode): (total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda" ).to(torch.float) - gated_act_type = args.gated_act_type - gated_act_type_to_func = { - 0: F.silu, - 1: F.gelu, + activation_type = args.activation_type + activation_type_to_func = { + ActivationType.Swiglu: F.silu, + ActivationType.Geglu: F.gelu, + ActivationType.Relu2: lambda x: F.relu(x) ** 2, } - gated_act_func = gated_act_type_to_func[gated_act_type] + activation_func = activation_type_to_func[activation_type] i = 0 for expert_idx in range(args.num_experts): @@ -1910,9 +1966,13 @@ def run_moe_dequant(args, quant_mode: QuantMode): if my_num_tokens == 0: continue my_a = gemm1_output[i : i + my_num_tokens] - my_x1 = my_a[:, : args.intermediate_size] - my_x2 = my_a[:, args.intermediate_size :] - activation_output[i : i + my_num_tokens] = gated_act_func(my_x2) * my_x1 + if is_gated_activation(args.activation_type): + my_x1 = my_a[:, : args.intermediate_size] + my_x2 = my_a[:, args.intermediate_size :] + activation_output[i : i + my_num_tokens] = activation_func(my_x2) * my_x1 + else: + my_x1 = my_a[:, : args.intermediate_size] + activation_output[i : i + my_num_tokens] = activation_func(my_x1) i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding @@ -2039,7 +2099,7 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.gated_act_type, + args.activation_type, ) return run_moe_dequant(args_dequant, quant_mode), args_dequant @@ -2104,7 +2164,7 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - GatedActType.SwiGlu.value, # gated_act_type + args.activation_type, ) return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant @@ -2141,7 +2201,7 @@ def run_moe_reference_per_tensor_scale_fp8(args): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - GatedActType.SwiGlu.value, # gated_act_type + args.activation_type, ) return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant @@ -2172,7 +2232,7 @@ def run_moe_reference_bf16(args): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - GatedActType.SwiGlu.value, # gated_act_type + args.activation_type, ) return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant @@ -2223,7 +2283,7 @@ def dequantize(weights, scales): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.gated_act_type, + args.activation_type, ) return run_moe_dequant(args_dequant, QuantMode.MXINT4_BF16_BF16), args_dequant @@ -2257,7 +2317,7 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "routed_scaling": kwargs["routed_scaling"], "routing_method_type": kwargs["routing_method_type"], "do_finalize": True, - "gated_act_type": args.gated_act_type, + "activation_type": args.activation_type, "hidden_states_scale": args.hidden_states_scale, "hidden_states_quant": kwargs["hidden_states_quant"], "enable_autotune": kwargs.get("enable_autotune", True), @@ -2285,7 +2345,7 @@ def run_moe_test( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, zero_hidden_states=False, ): @@ -2294,7 +2354,7 @@ def run_moe_test( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, num_tokens, hidden_size, intermediate_size, @@ -2319,7 +2379,7 @@ def run_moe_test( # Validation checks assert top_k <= num_experts - assert top_k <= 10 + assert top_k <= 22 if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0): assert top_k_groups <= 4 assert num_experts > n_groups @@ -2347,7 +2407,11 @@ def run_moe_test( (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 ) gemm1_weights = torch.randn( - (num_experts, 2 * intermediate_size, hidden_size), + ( + num_experts, + (2 if is_gated_activation(activation_type) else 1) * intermediate_size, + hidden_size, + ), device="cuda", dtype=torch.bfloat16, ) @@ -2432,7 +2496,7 @@ def run_moe_test( quant_data["gemm2_scales_global"], permute_info, use_routing_scales_on_input, - gated_act_type, + activation_type, ) # Compute reference output @@ -2601,10 +2665,10 @@ def run_moe_test( ], ) @pytest.mark.parametrize( - "gated_act_type", + "activation_type", [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), - pytest.param(GatedActType.GeGlu, id="GeGlu"), + pytest.param(ActivationType.Swiglu, id="Swiglu"), + pytest.param(ActivationType.Geglu, id="Geglu"), ], ) def test_renormalize_routing( @@ -2614,7 +2678,7 @@ def test_renormalize_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, zero_hidden_states, ): @@ -2626,7 +2690,7 @@ def test_renormalize_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, zero_hidden_states=zero_hidden_states, ) @@ -2635,10 +2699,11 @@ def test_renormalize_routing( # Test: DeepSeekV3 routing @pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) -@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize("intermediate_size", [2944, 2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ + pytest.param(FP8PerTensorMoe(), id="FP8_PerTensor"), pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), @@ -2650,6 +2715,23 @@ def test_renormalize_routing( @pytest.mark.parametrize( "routing_config", [ + pytest.param( + { + "num_experts": 512, + "top_k": 22, + "padding": 8, + "n_groups": 1, + "top_k_groups": 1, + "routed_scaling": 2.5, + "has_routing_bias": True, + "routing_method_type": RoutingMethodType.DeepSeekV3, + "compatible_moe_impls": [FP8PerTensorMoe, FP4Moe], + "compatible_intermediate_size": [2944], + "compatible_activation_types": [ActivationType.Relu2], + "enable_autotune": True, + }, + id="nemotron_3_dummy", + ), pytest.param( { "num_experts": 384, @@ -2662,6 +2744,10 @@ def test_renormalize_routing( "routing_method_type": RoutingMethodType.DeepSeekV3, "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], "compatible_intermediate_size": [1024, 2048], + "compatible_activation_types": [ + ActivationType.Swiglu, + ActivationType.Geglu, + ], "enable_autotune": True, }, id="kimi_k2", @@ -2683,6 +2769,10 @@ def test_renormalize_routing( BF16Moe, ], "compatible_intermediate_size": [512, 1024, 2048], + "compatible_activation_types": [ + ActivationType.Swiglu, + ActivationType.Geglu, + ], "enable_autotune": True, }, id="DSv3", @@ -2699,6 +2789,10 @@ def test_renormalize_routing( "routing_method_type": RoutingMethodType.DeepSeekV3, "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe], "compatible_intermediate_size": [384, 768], + "compatible_activation_types": [ + ActivationType.Swiglu, + ActivationType.Geglu, + ], "enable_autotune": False, }, id="DSLite", @@ -2715,6 +2809,10 @@ def test_renormalize_routing( "routing_method_type": RoutingMethodType.DeepSeekV3, "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe, BF16Moe], "compatible_intermediate_size": [512, 1024, 1536], + "compatible_activation_types": [ + ActivationType.Swiglu, + ActivationType.Geglu, + ], "enable_autotune": False, }, id="GLM4_MoE", @@ -2755,10 +2853,11 @@ def test_renormalize_routing( ], ) @pytest.mark.parametrize( - "gated_act_type", + "activation_type", [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), - pytest.param(GatedActType.GeGlu, id="GeGlu"), + pytest.param(ActivationType.Swiglu, id="Swiglu"), + pytest.param(ActivationType.Geglu, id="Geglu"), + pytest.param(ActivationType.Relu2, id="Relu2"), ], ) def test_deepseekv3_routing( @@ -2768,7 +2867,7 @@ def test_deepseekv3_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ): """Test DeepSeekV3 routing configurations.""" @@ -2779,7 +2878,7 @@ def test_deepseekv3_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ) @@ -2830,10 +2929,10 @@ def test_deepseekv3_routing( ], ) @pytest.mark.parametrize( - "gated_act_type", + "activation_type", [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), - pytest.param(GatedActType.GeGlu, id="GeGlu"), + pytest.param(ActivationType.Swiglu, id="Swiglu"), + pytest.param(ActivationType.Geglu, id="Geglu"), ], ) def test_topk_routing( @@ -2843,7 +2942,7 @@ def test_topk_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ): """Test TopK routing configuration.""" @@ -2854,7 +2953,7 @@ def test_topk_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ) @@ -2904,9 +3003,9 @@ def test_topk_routing( ], ) @pytest.mark.parametrize( - "gated_act_type", + "activation_type", [ - pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(ActivationType.Swiglu, id="Swiglu"), ], ) def test_llama4_routing( @@ -2916,7 +3015,7 @@ def test_llama4_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ): """Test Llama4 routing configuration with FP8 per-tensor.""" @@ -2927,6 +3026,6 @@ def test_llama4_routing( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, cache_permute_indices, ) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index 7a47444081..a5272ceb36 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -20,7 +20,7 @@ from flashinfer import ( RoutingMethodType, - GatedActType, + ActivationType, fp4_quantize, mxfp8_quantize, ) @@ -185,7 +185,7 @@ def test_trtllm_gen_routed_fused_moe( routing_method_type.value, True, # do_finalize enable_pdl, - GatedActType.SwiGlu.value, # gated_act_type + ActivationType.Swiglu.value, # act_type None, )[0].to(torch.float) @@ -238,7 +238,7 @@ def test_trtllm_gen_routed_fused_moe( routing_method_type.value, True, # do_finalize enable_pdl, - GatedActType.SwiGlu.value, # gated_act_type + ActivationType.Swiglu.value, # act_type None, )[0].to(torch.float) diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 19c01d5175..fae45f0415 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -17,7 +17,7 @@ import pytest import torch from enum import IntEnum -from flashinfer import GatedActType, RoutingMethodType +from flashinfer import ActivationType, RoutingMethodType from flashinfer.utils import get_compute_capability @@ -33,11 +33,25 @@ class QuantMode(IntEnum): MXINT4_BF16_BF16 = 7 +NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES = [ + QuantMode.FP4_NVFP4_NVFP4, + QuantMode.FP8_PER_TENSOR, +] + + +def is_gated_activation(activation_type: ActivationType) -> bool: + return activation_type in [ + ActivationType.Swiglu, + ActivationType.Geglu, + ActivationType.SwigluBias, + ] + + def skip_checks( moe_impl, routing_config, weight_processing, - gated_act_type, + activation_type, num_tokens, hidden_size, intermediate_size, @@ -57,24 +71,43 @@ def skip_checks( pytest.skip("Skipping zero hidden states tests for non-FP8 Block Scale MoE.") # Skip incompatible combinations - if gated_act_type == GatedActType.GeGlu and ( + if activation_type == ActivationType.Geglu and ( not is_fp4_moe or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4 or routing_config["routing_method_type"] != RoutingMethodType.TopK or num_tokens > 128 ): pytest.skip( - f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" + f"Incompatible: {moe_impl.name} + {activation_type} + {routing_config['routing_method_type']} + {num_tokens}" ) - elif gated_act_type == GatedActType.SwiGlu and ( + elif activation_type == ActivationType.Swiglu and ( hidden_size > 1024 or intermediate_size > 1024 ): pytest.skip( - f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" + f"Skip for testing speed: {activation_type} + {hidden_size} + {intermediate_size}" + ) + + compatible_activation_types = routing_config.get( + "compatible_activation_types", None + ) + if ( + compatible_activation_types is not None + and activation_type not in compatible_activation_types + ): + pytest.skip( + f"Incompatible: activation_type={activation_type} not in compatible_activation_types ({compatible_activation_types})" + ) + + if ( + not is_gated_activation(activation_type) + and moe_impl.quant_mode not in NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES + ): + pytest.skip( + f"Incompatible: {moe_impl.name} + {activation_type=} + quant_mode={moe_impl.quant_mode}: non-gated activations only supported with these quant modes: {NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES}" ) # Skip large intermediate sizes for configurations with many experts - if routing_config["num_experts"] >= 512 and intermediate_size > 512: + if routing_config["num_experts"] > 512 and intermediate_size > 512: pytest.skip( f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts" ) @@ -92,7 +125,7 @@ def skip_checks( f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)" ) - if type(moe_impl).__name__ == "MxInt4BlockScaleMoe" and ( + if moe_impl.quant_mode == QuantMode.MXINT4_BF16_BF16 and ( intermediate_size % 256 != 0 or hidden_size % 256 != 0 ): pytest.skip(