diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index 8ff7036dec..203faaff82 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -4,7 +4,7 @@ import numpy as np from flashinfer import ( RoutingMethodType, - ActivationType, + GatedActType, fp4_quantize, mxfp8_quantize, ) @@ -17,7 +17,6 @@ from flashinfer.autotuner import autotune from flashinfer.testing.utils import bench_gpu_time from flashinfer.utils import device_support_pdl -from routines.flashinfer_benchmark_utils import enum_type FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT4_E2M1_MAX = 6.0 @@ -40,7 +39,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( top_k: int, warmups: int, iterations: int, - activation_type: ActivationType, ): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) @@ -99,10 +97,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( ) if is_block_scale: - if activation_type != ActivationType.Swiglu: - raise ValueError( - "Only Swiglu activation is supported for FP8 block scale MoE." - ) fn = lambda: trtllm_fp8_block_scale_moe( routing_logits, routing_bias, @@ -150,7 +144,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( RoutingMethodType.TopK.value, enable_pdl, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, - activation_type.value, ) def bench(do_autotune): @@ -182,7 +175,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( top_k: int, warmups: int, iterations: int, - activation_type: ActivationType, ): device = torch.device("cuda:0") enable_pdl = device_support_pdl(device) @@ -242,10 +234,6 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( w13_global_scale = 1.0 / 448.0 / 6.0 w2_global_scale = 1.0 / 448.0 / 6.0 else: - if activation_type == ActivationType.Relu2: - raise ValueError( - "Relu2 activation is supported for FP4 only with 'NvFP4xNvFP4' quant mode" - ) w13, w13_scale = fp4_quantize( w13, torch.tensor([1.0], device=device), sf_vec_size=32, sf_use_ue8m0=True ) @@ -300,7 +288,7 @@ def bench_trtllm_gen_fused_moe_autotuner_fp4( RoutingMethodType.Renormalize.value, True, enable_pdl, - activation_type.value, # act_type + GatedActType.SwiGlu.value, # gated_act_type None, num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, ) @@ -360,14 +348,6 @@ def bench(do_autotune): parser.add_argument( "--iterations", type=int, default=100, help="Number of benchmark iterations" ) - parser.add_argument( - "--activation-type", - type=enum_type(ActivationType), - metavar=str([e.name for e in ActivationType]), - required=False, - default=ActivationType.Swiglu, - help=f"Type of activation function: {[e.name for e in ActivationType]}", - ) args = parser.parse_args() if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"]: bench_trtllm_gen_fused_moe_autotuner_fp8( @@ -380,7 +360,6 @@ def bench(do_autotune): args.top_k, args.warmups, args.iterations, - args.activation_type, ) else: bench_trtllm_gen_fused_moe_autotuner_fp4( @@ -393,5 +372,4 @@ def bench(do_autotune): args.top_k, args.warmups, args.iterations, - args.activation_type, ) diff --git a/benchmarks/routines/flashinfer_benchmark_utils.py b/benchmarks/routines/flashinfer_benchmark_utils.py index 375db471e4..b207f5cb43 100644 --- a/benchmarks/routines/flashinfer_benchmark_utils.py +++ b/benchmarks/routines/flashinfer_benchmark_utils.py @@ -1,4 +1,3 @@ -import argparse import torch from flashinfer.testing.utils import set_seed @@ -454,18 +453,3 @@ def filter_backends_by_compute_capability(backends, routine, device): f"[WARNING] {backend} for routine {routine} is not supported on compute capability {compute_capability}. Skipping." ) return backends - - -def enum_type(enum_class): - """Generic factory for argparse enum types.""" - - def converter(value): - try: - lower_name_to_member = {m.name.lower(): m for m in enum_class} - return lower_name_to_member[value.lower()] - except KeyError as e: - raise argparse.ArgumentTypeError( - f"Invalid value '{value}'. Must be one of: {', '.join([m.name for m in enum_class])}" - ) from e - - return converter diff --git a/benchmarks/routines/moe.py b/benchmarks/routines/moe.py index 16c221f483..2e4dd7bf06 100644 --- a/benchmarks/routines/moe.py +++ b/benchmarks/routines/moe.py @@ -5,7 +5,6 @@ import torch import flashinfer -from flashinfer import ActivationType from flashinfer.autotuner import autotune from flashinfer.fused_moe import ( trtllm_fp4_block_scale_moe, @@ -22,7 +21,6 @@ from .flashinfer_benchmark_utils import ( dtype_str_to_torch_dtype, - enum_type, get_device, print_perf_metrics, filter_backends_by_compute_capability, @@ -172,12 +170,12 @@ def parse_moe_args(line, parser): help="Data type of the weights (before quantization).", ) parser.add_argument( - "--activation-type", - type=enum_type(ActivationType), - metavar=str([e.name for e in ActivationType]), + "--gated_act", + type=str, required=False, - default=ActivationType.Swiglu, - help=f"Type of activation function: {[e.name for e in ActivationType]}", + default="swiglu", + choices=["swiglu", "geglu"], + help="Type of gated activation function: swiglu | geglu.", ) parser.add_argument( "--autotune", @@ -244,6 +242,13 @@ def parse_moe_args(line, parser): } args.routing_method_type = routing_method_name_to_type[args.routing_method] + # Normalize gated act type (map string to internal int expected by kernels) + gated_act_name_to_type = { + "swiglu": 0, + "geglu": 1, + } + args.gated_act_type = gated_act_name_to_type[args.gated_act] + if args.verbose >= 1: print(f"[INFO] {args = }") return args @@ -446,7 +451,7 @@ def testTrtllmFp4BlockScaleMoe(args): use_shuffled_weight = args.use_shuffled_weight weight_layout = args.weight_layout is_cuda_graph_compatible = not args.no_cuda_graph - activation_type = args.activation_type + gated_act_type = args.gated_act_type res = [] backends = ["trtllm"] @@ -605,7 +610,7 @@ def run_fp4_moe( local_num_experts=local_num_experts, routed_scaling_factor=routed_scaling_factor, routing_method_type=routing_method_type, - activation_type=activation_type.value, + gated_act_type=gated_act_type, do_finalize=True, ) @@ -710,7 +715,7 @@ def run_fp4_moe( cur_res["use_routing_scales_on_input"] = args.use_routing_scales_on_input cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = weight_dtype - cur_res["activation_type"] = args.activation_type.name + cur_res["gated_act"] = args.gated_act res.append(cur_res) return res @@ -1466,7 +1471,6 @@ def run_fp8_per_tensor_moe( output1_scales_gate_scalar, gemm2_weights_fp8, output2_scales_scalar, - activation_type, ): # Note: FP8 per-tensor MOE expects int64_t for n_group/topk_group, not Optional[int64_t] # So we convert None to 0 to indicate "no groups" mode @@ -1489,7 +1493,6 @@ def run_fp8_per_tensor_moe( routed_scaling_factor=routed_scaling_factor, use_routing_scales_on_input=use_routing_scales_on_input, routing_method_type=routing_method_type, - activation_type=activation_type.value, ) # Benchmark timing @@ -1510,7 +1513,6 @@ def run_fp8_per_tensor_moe( output1_scales_gate_scalar, gemm2_weights_fp8, output2_scales_scalar, - args.activation_type, ), ) @@ -1562,7 +1564,6 @@ def run_fp8_per_tensor_moe( cur_res["use_routing_scales_on_input"] = use_routing_scales_on_input cur_res["input_dtype"] = input_dtype cur_res["weight_dtype"] = weight_dtype - cur_res["activation_type"] = args.activation_type.name res.append(cur_res) return res diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f3eae5e9e3..f99e766e86 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -101,16 +101,14 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( options.mTransposeMmaOutput == mOptions.transposeMmaOutput && (!doesRouteImplUseNoRoute(options.mRouteImpl)) == mOptions.routeAct && options.mFusedAct == mOptions.fusedAct && options.mIsStaticBatch == mOptions.staticBatch && - tileSize == mOptions.tileSize && options.mUseShuffledMatrix == mOptions.useShuffledMatrix && + tileSize == mOptions.tileSize && + options.mUseShuffledMatrix == mOptions.useShuffledMatrixA && options.mLayoutA == mOptions.weightLayout) { if (options.mFusedAct) { if (options.mActType != static_cast(mOptions.actType)) { continue; } } - if ((int64_t)options.mEltwiseActType != (int64_t)mOptions.eltwiseActType) { - continue; - } if (mOptions.transposeMmaOutput && options.mEpilogueTileM == mOptions.epilogueTileM) { mPassingConfigIndices.push_back(i); @@ -124,8 +122,6 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( << ", mDtypeB: " << tg::dtypeToString(mOptions.dtypeB) << ", mDtypeC: " << tg::dtypeToString(mOptions.dtypeC) << ", mUseDeepSeekFp8: " << mOptions.deepSeekFp8 - << ", mActType: " << (int64_t)mOptions.actType - << ", mEltwiseActType: " << (int64_t)mOptions.eltwiseActType << ", mTransposeMmaOutput: " << mOptions.transposeMmaOutput << ", mRouteAct: " << mOptions.routeAct << ", mFusedAct: " << mOptions.fusedAct << ", mIsStaticBatch: " << mOptions.staticBatch << ", mTileSize: " << mOptions.tileSize; @@ -223,8 +219,6 @@ void TrtllmGenBatchedGemmRunner::run( gemmData.mInputBuffers.mPtrSfB = mOptions.transposeMmaOutput ? sfA : sfB; gemmData.mInputBuffers.mPtrScaleC = scaleC; gemmData.mInputBuffers.mPtrScaleGate = scaleGateC; - // For simplicity pass set scaleAct to scaleGateC - gemmData.mInputBuffers.mPtrScaleAct = scaleGateC; gemmData.mInputBuffers.mPtrPerTokenSfA = mOptions.transposeMmaOutput ? perTokensSfB : perTokensSfA; gemmData.mInputBuffers.mPtrPerTokenSfB = diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index 2b0efb7cc9..cecf4efb7a 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -36,7 +36,7 @@ namespace flashinfer { namespace btg = batchedGemm::trtllm::gen; -using tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType; +using tensorrt_llm::kernels::trtllmgen_moe::MoE::GatedActType; using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; @@ -109,7 +109,7 @@ class FusedMoeLauncher { btg::Dtype mDtypeWeights{btg::Dtype::Bfloat16}; btg::Dtype mRoutingBiasDtype{ btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias - ActivationType activation_type{ActivationType::Swiglu}; + GatedActType gated_act_type{GatedActType::SwiGlu}; public: // Constructor that initializes all TensorView members @@ -134,14 +134,14 @@ class FusedMoeLauncher { weight_layout{batchedGemm::gemm::MatrixLayout::MajorK}, mDtypeAct{btg::Dtype::Bfloat16}, mDtypeWeights{btg::Dtype::Bfloat16}, - activation_type{ActivationType::Swiglu} {} + gated_act_type{GatedActType::SwiGlu} {} protected: // Initialize common data necessary for later. // May throw exception from TVM_FFI_ICHECK. void init_common(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, ActivationType activation_type); + int64_t weight_layout, int64_t gated_act_type); // Routing logits [num_tokens, num_experts] void check_routing_logits_shape() const { @@ -305,9 +305,10 @@ class FusedMoeLauncher { (int32_t)tile_tokens_dim, this->use_shuffled_weight, this->weight_layout); } else { - moe_runner = std::make_unique( - this->mDtypeAct, this->mDtypeWeights, args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, - this->activation_type, this->use_shuffled_weight, this->weight_layout); + moe_runner = std::make_unique(this->mDtypeAct, this->mDtypeWeights, + args->mUseDeepSeekFp8, (int32_t)tile_tokens_dim, + static_cast(this->gated_act_type), + this->use_shuffled_weight, this->weight_layout); } if (moe_tactic == -1) { @@ -376,7 +377,7 @@ class FusedMoeLauncher { void FusedMoeLauncher::init_common( std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, ActivationType activation_type) { + int64_t weight_layout, int64_t gated_act_type) { // Check devicearchitecture: Blackwell (SM 10.x) required auto device = hidden_states.device().device_id; int major = 0, minor = 0; @@ -399,7 +400,9 @@ void FusedMoeLauncher::init_common( TVM_FFI_ICHECK(0 <= weight_layout && weight_layout <= 2) << "the value of weight_layout is not recognized"; this->weight_layout = static_cast(weight_layout); - this->activation_type = activation_type; + TVM_FFI_ICHECK(0 <= gated_act_type && gated_act_type <= 1) + << "the value of gated_act_type is not recognized"; + this->gated_act_type = static_cast(gated_act_type); } class Bf16MoeLauncher : public FusedMoeLauncher { @@ -416,12 +419,12 @@ class Bf16MoeLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { - constexpr ActivationType activation_type = - ActivationType::Swiglu; // not exposed in api for now + constexpr int64_t gated_act_type = + static_cast(GatedActType::SwiGlu); // not exposed in api for now // Do base class init and perform common checks FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, activation_type); + use_shuffled_weight, weight_layout, gated_act_type); } void check_routing() const override { @@ -486,7 +489,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t act_type, + int64_t num_tokens, int64_t gated_act_type, bool use_shuffled_weight, int64_t weight_layout) { Array> valid_configs; @@ -499,7 +502,7 @@ class Bf16MoeLauncher : public FusedMoeLauncher { btg::Dtype::Bfloat16, // dtype_act btg::Dtype::Bfloat16, // dtype_weights false, // useDeepSeekFp8 - tile_N, static_cast(act_type), use_shuffled_weight, + tile_N, static_cast(gated_act_type), use_shuffled_weight, static_cast(weight_layout)); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -532,8 +535,10 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, bool use_routing_scales_on_input_param, - ActivationType activation_type) { + int64_t weight_layout, bool use_routing_scales_on_input_param) { + constexpr int64_t gated_act_type = + static_cast(GatedActType::SwiGlu); // not exposed in api for now + this->use_routing_scales_on_input = use_routing_scales_on_input_param; auto dtype = hidden_states.dtype(); @@ -549,7 +554,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { mDtypeWeights = btg::Dtype::E4m3; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, activation_type); + use_shuffled_weight, weight_layout, gated_act_type); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -677,7 +682,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { public: static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t act_type, + int64_t num_tokens, int64_t gated_act_type, bool use_shuffled_weight, int64_t weight_layout, btg::Dtype dtype_act, btg::Dtype dtype_weights) { Array> valid_configs; @@ -690,7 +695,7 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( dtype_act, dtype_weights, false, // useDeepSeekFp8 - tile_N, static_cast(act_type), use_shuffled_weight, + tile_N, static_cast(gated_act_type), use_shuffled_weight, static_cast(weight_layout)); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -727,7 +732,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { - constexpr ActivationType activation_type = ActivationType::Swiglu; + constexpr int64_t gated_act_type = static_cast(GatedActType::SwiGlu); mDtypeAct = btg::Dtype::E4m3; mDtypeWeights = btg::Dtype::E4m3; @@ -747,7 +752,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { args->mDtypeOut = btg::Dtype::Bfloat16; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, activation_type); + use_shuffled_weight, weight_layout, gated_act_type); } void check_routing() const override { @@ -1044,7 +1049,8 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { FusedMoeLauncher::init_common( std::move(args), tile_tokens_dim, routing_method_type, /*use_shuffled_weight=*/true, - static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), ActivationType::Swiglu); + static_cast(batchedGemm::gemm::MatrixLayout::BlockMajorK), + static_cast(GatedActType::SwiGlu)); } void check_routing() const override { FusedMoeLauncher::check_routing_common(); } @@ -1147,8 +1153,8 @@ class MxInt4BlockScaleLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( btg::Dtype::Bfloat16, btg::Dtype::MxInt4, false, // useDeepSeekFp8 - tile_N, ActivationType::Swiglu, - /*useShuffledMatrix*/ true, batchedGemm::gemm::MatrixLayout::BlockMajorK); + tile_N, GatedActType::SwiGlu, + /*useShuffledMatrixA*/ true, batchedGemm::gemm::MatrixLayout::BlockMajorK); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, num_tokens); @@ -1202,7 +1208,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout, ActivationType activation_type, btg::Dtype dtype_act, + int64_t weight_layout, int64_t gated_act_type, btg::Dtype dtype_act, btg::Dtype dtype_weights) { static const std::tuple device_props = [this] { int major, minor; @@ -1226,7 +1232,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { mDtypeWeights = dtype_weights; FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, - use_shuffled_weight, weight_layout, activation_type); + use_shuffled_weight, weight_layout, gated_act_type); } void check_routing() const override { @@ -1446,7 +1452,7 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, - int64_t num_tokens, int64_t act_type, + int64_t num_tokens, int64_t gated_act_type, btg::Dtype dtype_act, btg::Dtype dtype_weights) { Array> valid_configs; @@ -1458,8 +1464,8 @@ class FP4BlockScaleLauncher : public FusedMoeLauncher { auto moe_runner = std::make_unique( dtype_act, dtype_weights, false, // useDeepSeekFp8 - tile_N, static_cast(act_type), - /*useShuffledMatrix*/ true); // FP4 uses shuffled weights + tile_N, static_cast(gated_act_type), + /*useShuffledMatrixA*/ true); // FP4 uses shuffled weights auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, num_local_experts, num_tokens); @@ -1552,10 +1558,9 @@ Tensor trtllm_fp8_per_tensor_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, bool use_routing_scales_on_input, int64_t routing_method_type, bool enable_pdl, - Array config_index, int64_t activation_type) { + Array config_index) { // Basic type validation auto dtype = hidden_states.dtype(); - auto activation = static_cast(activation_type); if (use_routing_scales_on_input) { TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; } else if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { @@ -1580,7 +1585,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe( auto const hidden_size = hidden_states.size(1); // Use default values that match the original function behavior - bool use_shuffled_weight = true; // Original uses /*useShuffledMatrix*/ true + bool use_shuffled_weight = true; // Original uses /*useShuffledMatrixA*/ true int64_t weight_layout = 0; // Default to MajorK // Calculate supported tile sizes @@ -1612,7 +1617,7 @@ Tensor trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, - weight_layout, use_routing_scales_on_input, activation); + weight_layout, use_routing_scales_on_input); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1745,7 +1750,7 @@ Array trtllm_fp4_block_scale_moe( Optional output2_scales_scalar, int64_t num_experts, int64_t top_k, Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, - int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t act_type, + int64_t routing_method_type, bool do_finalize, bool enable_pdl, int64_t gated_act_type, TensorView output, Array config_index) { // Determine data types based on input format int const num_tokens = hidden_states.size(0); @@ -1756,11 +1761,8 @@ Array trtllm_fp4_block_scale_moe( if (hidden_states_scale.has_value()) { hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel(); } - int64_t intermediate_size_factor = - isGatedActivation(static_cast(act_type)) ? 2 : 1; int weight_scale_vec_size = - (local_num_experts * intermediate_size * intermediate_size_factor * hidden_size) / - gemm1_weights_scale.numel(); + (local_num_experts * intermediate_size * 2 * hidden_size) / gemm1_weights_scale.numel(); TVM_FFI_ICHECK(weight_scale_vec_size == 16 || weight_scale_vec_size == 32) << "unsupported weight_scale_vec_size."; @@ -1853,8 +1855,7 @@ Array trtllm_fp4_block_scale_moe( gemm2_weights_scale, gemm2_bias, output1_scales_scalar, output1_scales_gate_scalar, output2_scales_scalar, expert_indices, expert_weights); launcher->init(std::move(args), curr_tile_N, routing_method_type, /*use_shuffled_weight=*/true, - /*weight_layout=*/0, static_cast(act_type), mDtypeAct, - mDtypeWeights); + /*weight_layout=*/0, gated_act_type, mDtypeAct, mDtypeWeights); launchers_map[curr_tile_N] = std::move(launcher); } @@ -1967,7 +1968,7 @@ Array trtllm_mxint4_block_scale_moe( Array> trtllm_get_valid_moe_configs( int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8, int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size, - int64_t const num_local_experts, int64_t const act_type, bool const use_shuffled_weight, + int64_t const num_local_experts, int64_t const gated_act_type, bool const use_shuffled_weight, int64_t const weight_layout, int64_t const num_tokens) { auto dtype_act = static_cast(dtype_act_); auto dtype_weights = static_cast(dtype_weights_); @@ -1980,7 +1981,7 @@ Array> trtllm_get_valid_moe_configs( if (dtype_act == btg::Dtype::Bfloat16 && dtype_weights == btg::Dtype::Bfloat16) { // BF16 MoE return Bf16MoeLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens, act_type, + num_local_experts, num_tokens, gated_act_type, use_shuffled_weight, weight_layout); } else if (dtype_act == btg::Dtype::E4m3 && dtype_weights == btg::Dtype::E4m3) { @@ -1988,7 +1989,7 @@ Array> trtllm_get_valid_moe_configs( if (!useDeepSeekFp8) { // FP8 per-tensor scale return Fp8PerTensorLauncher::getValidConfigs( - top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, act_type, + top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, gated_act_type, use_shuffled_weight, weight_layout, dtype_act, dtype_weights); } else { // FP8 block scale @@ -1999,7 +2000,7 @@ Array> trtllm_get_valid_moe_configs( } else if (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) { // FP4 block scale return FP4BlockScaleLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, - num_local_experts, num_tokens, act_type, + num_local_experts, num_tokens, gated_act_type, dtype_act, dtype_weights); } diff --git a/csrc/trtllm_fused_moe_routing_deepseek.cu b/csrc/trtllm_fused_moe_routing_deepseek.cu index 422194cf7f..21faec8ec7 100644 --- a/csrc/trtllm_fused_moe_routing_deepseek.cu +++ b/csrc/trtllm_fused_moe_routing_deepseek.cu @@ -14,7 +14,6 @@ * limitations under the License. */ -#include #include #include "flashinfer/exception.h" @@ -26,14 +25,10 @@ namespace routingDeepSeek { //////////////////////////////////////////////////////////////////////////////////////////////////// -static constexpr int NumNemotronExperts = 512; static constexpr int NumKimiK2Experts = 384; static constexpr int NumDeepseekExperts = 256; -static constexpr int MaxSupportedExpertCount = - std::max({NumNemotronExperts, NumKimiK2Experts, NumDeepseekExperts}); static constexpr int NumTopGroupScores = 2; -static constexpr int DefaultMaxNumTopExperts = 8; -static constexpr int MaxSupportedTopExperts = 22; +static constexpr int MaxNumTopExperts = 8; static constexpr int MaxNumTopGroups = 4; static constexpr int MaxNumGroups = 8; @@ -122,8 +117,8 @@ __global__ void routingMainKernel(KernelParams params) { int32_t topGroupIdx[MaxNumTopGroups]; float expertScoreGroup[MaxNumTopGroups]; int32_t expertIdxGroup[MaxNumTopGroups]; - float topScores[KernelParams::MaxNumTopExperts]; // bound of params.mTopK - int32_t topExperts[KernelParams::MaxNumTopExperts]; + float topScores[MaxNumTopExperts]; // bound of params.mTopK + int32_t topExperts[MaxNumTopExperts]; if constexpr (KernelParams::UseGroups) { topk::reduceTopK(warp, topExpGroupScores, topExpGroupIdx, scoreBias, threadExpert, @@ -159,8 +154,7 @@ __global__ void routingMainKernel(KernelParams params) { // params.mNumExpertsPerGroup // => expertIdxGroup[ii] < params.mNumExperts <= NumThreads, // so the access is safe here - expertScoreGroup[ii] = (ii < params.mNumLimitedGroups) && - (groupIdx < params.mNumExpertGroups) && expertSelected + expertScoreGroup[ii] = groupIdx < params.mNumExpertGroups && expertSelected ? smemScoreBias[expertIdxGroup[ii]] : invalidScoreFloat; } @@ -172,7 +166,7 @@ __global__ void routingMainKernel(KernelParams params) { // without groups, each thread just takes `MaxNumTopGroups` experts int constexpr NumExpertWarps = (KernelParams::MaxNumExperts - 1) / topk::MaxNumExpertsUnit + 1; - int constexpr NumInterTopK = NumExpertWarps * KernelParams::MaxNumTopExperts; + int constexpr NumInterTopK = NumExpertWarps * MaxNumTopExperts; __shared__ float __attribute((aligned(128))) smemInterTopScores[NumInterTopK]; __shared__ int32_t __attribute((aligned(128))) smemInterTopExperts[NumInterTopK]; if (warpIdx < NumExpertWarps) { @@ -189,20 +183,13 @@ __global__ void routingMainKernel(KernelParams params) { /* minValue */ invalidScoreFloat, params.mTopK); if (laneIdx < params.mTopK) { - smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = - topScores[laneIdx]; - smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = - topExperts[laneIdx]; - } else if (laneIdx >= params.mTopK && laneIdx < KernelParams::MaxNumTopExperts) { - smemInterTopScores[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = - invalidScoreFloat; - smemInterTopExperts[warpIdx * KernelParams::MaxNumTopExperts + laneIdx] = - MaxSupportedExpertCount - 1; + smemInterTopScores[warpIdx * MaxNumTopExperts + laneIdx] = topScores[laneIdx]; + smemInterTopExperts[warpIdx * MaxNumTopExperts + laneIdx] = topExperts[laneIdx]; } } __syncthreads(); if (warpIdx == 0) { - int constexpr NumInterTopKPerThread = (NumInterTopK - 1) / WarpSize + 1; + int constexpr NumInterTopKPerThread = (NumInterTopK * NumExpertWarps - 1) / WarpSize + 1; float intermidiateScore[NumInterTopKPerThread]; int32_t intermidiateExpert[NumInterTopKPerThread]; for (int i = laneIdx; i < NumInterTopKPerThread * WarpSize; i += WarpSize) { @@ -283,7 +270,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) cudaGridDependencySynchronize(); } routingPermutation(params, nullptr, warpIdx, clusterBlockRank); } #else @@ -506,8 +493,6 @@ int constexpr getMaxNumExperts(int32_t numExperts) { return NumDeepseekExperts; } else if (numExperts <= NumKimiK2Experts) { return NumKimiK2Experts; - } else if (numExperts <= NumNemotronExperts) { - return NumNemotronExperts; } else { TLLM_LOG_ERROR("Unsupported numExperts"); return 0; @@ -519,23 +504,13 @@ int constexpr getMaxNumExperts(int32_t numExperts) { extraFlag) \ if (data.mNumExperts <= topk::MaxNumExpertsUnit) { \ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, topk::MaxNumExpertsUnit, \ - DefaultMaxNumTopExperts); \ + stream, extraFlag, topk::MaxNumExpertsUnit); \ } else if (data.mNumExperts <= NumDeepseekExperts) { \ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumDeepseekExperts, DefaultMaxNumTopExperts); \ + stream, extraFlag, NumDeepseekExperts); \ } else if (data.mNumExperts <= NumKimiK2Experts) { \ LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumKimiK2Experts, DefaultMaxNumTopExperts); \ - } else if (data.mNumExperts <= NumNemotronExperts) { \ - if (data.mTopK <= DefaultMaxNumTopExperts) { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumNemotronExperts, \ - DefaultMaxNumTopExperts); \ - } else { \ - LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, NumNemotronExperts, MaxSupportedTopExperts); \ - } \ + stream, extraFlag, NumKimiK2Experts); \ } else { \ TLLM_LOG_ERROR("Unsupported numExperts"); \ } @@ -557,20 +532,20 @@ void runImpl(Data& data, void* stream) { FLASHINFER_CHECK(data.mNumLimitedGroups <= MaxNumTopGroups, "Routing kernel expects <= %d top groups, got %d", MaxNumTopGroups, data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mTopK <= MaxSupportedTopExperts, - "Routing kernel expects topK experts <= %d, got %d", MaxSupportedTopExperts, + FLASHINFER_CHECK(data.mTopK <= MaxNumTopExperts, + "Routing kernel expects topK experts <= %d, got %d", MaxNumTopExperts, data.mTopK); FLASHINFER_CHECK(data.mTopK <= WarpSize, "Routing kernel expects top K <= warp size, got %d", data.mTopK); FLASHINFER_CHECK(data.mTopK * data.mNumLimitedGroups <= WarpSize, "Routing kernel expects top K * top groups <= warp size (for now), got %d * %d", data.mTopK, data.mNumLimitedGroups); - FLASHINFER_CHECK(data.mNumExperts >= MaxSupportedTopExperts, - "Routing kernel expects %d to be at most #experts %d", MaxSupportedTopExperts, + FLASHINFER_CHECK(data.mNumExperts >= MaxNumTopExperts, + "Routing kernel expects %d to be at most #experts %d", MaxNumTopExperts, data.mNumExperts); - FLASHINFER_CHECK(data.mNumExperts <= MaxSupportedExpertCount, + FLASHINFER_CHECK(data.mNumExperts <= NumKimiK2Experts, "Routing kernel expects #experts %d <= #threads %d", data.mNumExperts, - MaxSupportedExpertCount); + NumKimiK2Experts); FLASHINFER_CHECK(data.mNumExpertGroups >= data.mNumLimitedGroups, "Routing kernel expects top groups %d to be limited by #expert groups %d", data.mNumLimitedGroups, data.mNumExpertGroups); @@ -585,6 +560,10 @@ void runImpl(Data& data, void* stream) { data.mNumExperts / data.mNumExpertGroups <= WarpSize, "Routing kernel expects #experts per group <= warp size, got %d, data.mNumExpertGroups %d", data.mNumExperts / data.mNumExpertGroups, data.mNumExpertGroups); + } else { + FLASHINFER_CHECK(data.mTopK <= topk::MaxNumTopK, + "Routing kernel expects top K %d to be <= #warps %d", data.mTopK, + topk::MaxNumTopK); } FLASHINFER_CHECK(data.mNumExperts % 4 == 0, "Routing kernel expects #experts %d to be a multiple of 4.", data.mNumExperts); @@ -619,7 +598,7 @@ void runImpl(Data& data, void* stream) { int const maxTokensCoop = (numBlocksCoop * numThreadsHist * 64) / data.mTopK; if (data.mPtrTopKIds == nullptr) { int const numThreadsMain = - max(data.mNumExpertGroups * WarpSize, getMaxNumExperts(data.mNumExperts)); + data.mNumExperts < NumDeepseekExperts ? NumDeepseekExperts : NumKimiK2Experts; LAUNCH_ROUTING_DEEPSEEK(data, /*coopLaunch=*/false, routingMainKernel, numBlocks, numThreadsMain, /*smemSize=*/0, // No dynamic smem diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index e3615fa1c4..b5ff5757c9 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -60,7 +60,7 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 bool useRoutingScalesOnInput, bool useDeepSeekFp8, RoutingMethodType routingMethodType, cudaStream_t stream) { if (routingMethodType == RoutingMethodType::DeepSeekV3) { - FLASHINFER_CHECK(topK <= 22, "For DeepSeek routing method, must have topK <= 22"); + FLASHINFER_CHECK(topK <= 8, "For DeepSeek routing method, must have topK <= 8"); FLASHINFER_CHECK(topkGroup <= 4, "For DeepSeek routing method, must have topkGroup <= 4"); moe::dev::routing::routingDeepSeek::Data routingData; routingData.mDtypeExpW = @@ -189,49 +189,13 @@ void Runner::run(void* routingLogits, void* routingBias, int32_t numTokens, int3 namespace PermuteGemm1 { -using tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType; -using tensorrt_llm::kernels::trtllmgen_moe::MoE::isGatedActivation; -using tensorrt_llm::kernels::trtllmgen_moe::MoE::serializeActivationType; - -static inline ActType activationTypeToGatedActType(ActivationType actType) { - switch (actType) { - case ActivationType::Swiglu: - return ActType::SwiGlu; - case ActivationType::Geglu: - return ActType::GeGlu; - default: - FLASHINFER_CHECK(false, "Unsupported gated activation type ", - serializeActivationType(actType), " of enum ", - static_cast(actType)); - } - return ActType::SwiGlu; -} - -static inline EltwiseActType activationTypeToEltwiseActType(ActivationType actType) { - switch (actType) { - case ActivationType::Relu2: - return EltwiseActType::Relu2; - case ActivationType::Identity: - return EltwiseActType::None; - default: - FLASHINFER_CHECK(false, "Unsupported eltwise activation type ", - serializeActivationType(actType), " of enum ", - static_cast(actType)); - } - return EltwiseActType::None; -} - tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( btg::Dtype dtypeAct, btg::Dtype dtypeWeights, int32_t tileTokensDim, bool useDeepSeekFp8, - ActivationType activationType, bool useShuffledMatrix, + MoE::GatedActType gatedActType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) { - int64_t actTypeInt = static_cast(activationType); - FLASHINFER_CHECK( - 0 <= actTypeInt && actTypeInt < static_cast(ActivationType::InvalidType), - "Unknown activation type", serializeActivationType(activationType), "of enum", actTypeInt); - bool isGatedAct = isGatedActivation(activationType); - if (isGatedAct) { - ActType actType = activationTypeToGatedActType(activationType); + if (gatedActType == MoE::GatedActType::SwiGlu || gatedActType == MoE::GatedActType::GeGlu) { + ActType actType = + (gatedActType == MoE::GatedActType::SwiGlu) ? ActType::SwiGlu : ActType::GeGlu; tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { // Swap A and B dtypes because transposeMmaOutput is hardcoded to true .dtypeA = dtypeWeights, @@ -245,40 +209,24 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 64 : 128, - .useShuffledMatrix = useShuffledMatrix, + .useShuffledMatrixA = useShuffledMatrixA, .weightLayout = weightLayout}; return options; } else { - EltwiseActType actType = activationTypeToEltwiseActType(activationType); - tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { - // Swap A and B dtypes because transposeMmaOutput is hardcoded to true - .dtypeA = dtypeWeights, - .dtypeB = dtypeAct, - .dtypeC = dtypeAct, - .eltwiseActType = actType, - .deepSeekFp8 = useDeepSeekFp8, - .fusedAct = false, - .routeAct = true, - .staticBatch = false, - .transposeMmaOutput = true, - .tileSize = tileTokensDim, - .epilogueTileM = 128, - .useShuffledMatrix = useShuffledMatrix, - .weightLayout = weightLayout}; - return options; + FLASHINFER_CHECK(false, "Unimplemented gated act type ", + MoE::serializeGatedActType(gatedActType), " of enum ", (int)gatedActType); } } Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim, - ActivationType activationType, bool useShuffledMatrix, + MoE::GatedActType gatedActType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) : mDtypeAct(dtypeAct), mDtypeWeights(dtypeWeights), mTileTokensDim(tileTokensDim), mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( - getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, activationType, - useShuffledMatrix, weightLayout))), - mActType(activationType) {} + getOptions(mDtypeAct, mDtypeWeights, mTileTokensDim, useDeepSeekFp8, gatedActType, + useShuffledMatrixA, weightLayout))) {} void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* weightsScale, void* expertWeights, float* outputScalesScalar, float* outputScalesGateScalar, @@ -291,14 +239,12 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl) { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); - mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, - numExperts, maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, - weightsScale, expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, - outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, - outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, - ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device, - configIndex, enable_pdl); + mRunner.run(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, + maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, weightsScale, + expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, + ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, + ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, + ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex, enable_pdl); } size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -306,10 +252,8 @@ size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t int32_t configIndex) const { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); - return mRunner.getWorkspaceSizeInBytes(numTokens, intermediateSizeFactor * intermediateSize, - hiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim, configIndex); + return mRunner.getWorkspaceSizeInBytes(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, + numExperts, maxNumCtasInBatchDim, configIndex); } int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, @@ -317,10 +261,8 @@ int32_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, int32_t numTokens) const { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); - return mRunner.getDefaultValidConfigIndex(numTokens, intermediateSizeFactor * intermediateSize, - hiddenSize, {}, numTokens, numExperts, - maxNumCtasInBatchDim); + return mRunner.getDefaultValidConfigIndex(numTokens, 2 * intermediateSize, hiddenSize, {}, + numTokens, numExperts, maxNumCtasInBatchDim); } bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hiddenSize, @@ -329,10 +271,9 @@ bool Runner::isValidConfigIndex(int32_t configIndex, int32_t topK, int32_t hidde auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); - int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); auto const isValid = - mRunner.isValidConfigIndex(configIndex, numTokens, intermediateSizeFactor * intermediateSize, - hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim); + mRunner.isValidConfigIndex(configIndex, numTokens, 2 * intermediateSize, hiddenSize, {}, + numTokens, numExperts, maxNumCtasInBatchDim); return isValid; } @@ -345,13 +286,12 @@ std::vector Runner::getPassingConfigIndices() const { namespace Gemm2 { tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, int32_t tileTokensDim, - bool useDeepSeekFp8, bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) { + bool useDeepSeekFp8, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) { tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions options = { // Swap A and B dtypes because transposeMmaOutput is hardcoded to true .dtypeA = dtypeWeights, .dtypeB = dtypeAct, .dtypeC = dtypeOut, - .eltwiseActType = EltwiseActType::None, .deepSeekFp8 = useDeepSeekFp8, .fusedAct = false, .routeAct = false, @@ -359,13 +299,13 @@ tensorrt_llm::kernels::TrtllmGenBatchedGemmRunnerOptions getOptions( .transposeMmaOutput = true, .tileSize = tileTokensDim, .epilogueTileM = useDeepSeekFp8 ? 64 : 128, - .useShuffledMatrix = useShuffledMatrix, + .useShuffledMatrixA = useShuffledMatrixA, .weightLayout = weightLayout}; return options; } Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut, - bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrix, + bool useDeepSeekFp8, int tileTokensDim, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) : mDtypeAct(dtypeAct), mDtypeWeights(dtypeWeights), @@ -373,7 +313,7 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, btg::Dtype dtypeOut mTileTokensDim(tileTokensDim), mRunner(tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner( getOptions(dtypeAct, dtypeWeights, dtypeOut, tileTokensDim, useDeepSeekFp8, - useShuffledMatrix, weightLayout))) {} + useShuffledMatrixA, weightLayout))) {} void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void* weights, void* weightsScale, float* outputScalesScalar, float* ptrBias, void* output, @@ -433,12 +373,12 @@ std::vector Runner::getPassingConfigIndices() const { namespace MoE { Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8, - int32_t tileTokensDim, ActivationType activationType, bool useShuffledMatrix, + int32_t tileTokensDim, GatedActType gatedActType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) : mPermuteGemm1(PermuteGemm1::Runner(dtypeAct, dtypeWeights, useDeepSeekFp8, tileTokensDim, - activationType, useShuffledMatrix, weightLayout)), + gatedActType, useShuffledMatrixA, weightLayout)), mGemm2(Gemm2::Runner(dtypeAct, dtypeWeights, btg::Dtype::Bfloat16, useDeepSeekFp8, - tileTokensDim, useShuffledMatrix, weightLayout)) { + tileTokensDim, useShuffledMatrixA, weightLayout)) { auto const& gemm1PassingIndices = mPermuteGemm1.getPassingConfigIndices(); auto const& gemm2PassingIndices = mGemm2.getPassingConfigIndices(); @@ -455,9 +395,9 @@ Runner::Runner(btg::Dtype dtypeAct, btg::Dtype dtypeWeights, bool useDeepSeekFp8 } Runner::Runner(btg::Dtype dtypeElt, bool useDeepSeekFp8, int32_t tileTokensDim, - bool useShuffledMatrix, batchedGemm::gemm::MatrixLayout weightLayout) - : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, ActivationType::Swiglu, - useShuffledMatrix, weightLayout) {} + bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weightLayout) + : Runner(dtypeElt, dtypeElt, useDeepSeekFp8, tileTokensDim, GatedActType::SwiGlu, + useShuffledMatrixA, weightLayout) {} void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace, moe::dev::convertsf::Data& convertSfData, @@ -480,8 +420,7 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace activationData.outPtr = workspace.activation_output; activationData.inDqSfsPtr = workspace.gemm1_output_scale; activationData.outDqSfsPtr = workspace.activation_output_scale; - activationData.innerDim = - args.intermediate_size * (isGatedActivation(args.activation_type) ? 2 : 1); + activationData.innerDim = args.intermediate_size * 2; activationData.topK = args.top_k; activationData.numTokens = args.num_tokens; activationData.expandedIdxToPermutedIdx = workspace.expanded_idx_to_permuted_idx; diff --git a/flashinfer/__init__.py b/flashinfer/__init__.py index c78ceb215b..c22b4a0a55 100644 --- a/flashinfer/__init__.py +++ b/flashinfer/__init__.py @@ -75,8 +75,8 @@ ) from .fp8_quantization import mxfp8_dequantize_host, mxfp8_quantize from .fused_moe import ( - ActivationType, RoutingMethodType, + GatedActType, cutlass_fused_moe, reorder_rows_for_gated_act_gemm, trtllm_fp4_block_scale_moe, diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index a077ea82d5..f7886fe400 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -15,8 +15,8 @@ """ from .core import ( - ActivationType, RoutingMethodType, + GatedActType, WeightLayout, convert_to_block_layout, cutlass_fused_moe, @@ -40,8 +40,8 @@ ) __all__ = [ - "ActivationType", "RoutingMethodType", + "GatedActType", "WeightLayout", "convert_to_block_layout", "cutlass_fused_moe", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 2821ce829a..0e9d643b4c 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -173,6 +173,15 @@ class WeightLayout(IntEnum): BlockMajorK = 2 +# The type of gated activation function +# Please keep this in sync with the counterpart defined in include/flashinfer/trtllm/fused_moe/runner.h +class GatedActType(IntEnum): + # SwiGlu + SwiGlu = 0 + # GeGlu + GeGlu = 1 + + @functools.cache def is_trtllm_moe_supported( dtype_weights: DtypeTrtllmGen, @@ -212,16 +221,12 @@ def _maybe_get_cached_w3_w1_permute_indices( dst_w3_w1_weight: torch.Tensor, epilogue_tile_m: int, num_elts_per_sf: Union[None, int] = None, - is_gated_act_gemm: bool = True, ) -> torch.Tensor: # Create a unique cache key (weight_type, weight_shape) cache_key = ("w3_w1", dst_w3_w1_weight.shape) if cache_key not in _cache_permute_indices: # Get permute indices and chain them together - if is_gated_act_gemm: - permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight) - else: - permute0 = torch.arange(dst_w3_w1_weight.shape[0], dtype=torch.long) + permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight) if num_elts_per_sf is None: permute1 = get_shuffle_matrix_a_row_indices( dst_w3_w1_weight, epilogue_tile_m=epilogue_tile_m @@ -989,7 +994,7 @@ def __init__( use_deepseek_fp8: bool, hidden_size: int, intermediate_size: int, - activation_type: int = ActivationType.Swiglu, + gated_act_type: int = GatedActType.SwiGlu, use_shuffled_weight: bool = False, weight_layout: int = WeightLayout.MajorK, use_packed_weights: bool = False, @@ -1002,7 +1007,7 @@ def __init__( self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.activation_type = ActivationType(activation_type) + self.gated_act_type = GatedActType(gated_act_type) self.use_shuffled_weight = use_shuffled_weight self.weight_layout = WeightLayout(weight_layout) self.use_packed_weights = use_packed_weights @@ -1030,7 +1035,7 @@ def get_valid_tactics( self.hidden_size, self.intermediate_size, self.num_local_experts, - self.activation_type, + self.gated_act_type, self.use_shuffled_weight, self.weight_layout, num_tokens, @@ -1174,7 +1179,6 @@ def forward( kwargs["routing_method_type"], kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, - self.activation_type, ) elif ( self.dtype_act == DtypeTrtllmGen.Bfloat16 @@ -1235,7 +1239,7 @@ def forward( kwargs["routing_method_type"], kwargs["enable_pdl"], kwargs["do_finalize"], - self.activation_type, + self.gated_act_type, output, [-1, -1] if tactic == -1 else tactic, ) @@ -1324,7 +1328,7 @@ def trtllm_bf16_moe_op( intermediate_size=intermediate_size, weight_layout=weight_layout, use_shuffled_weight=use_shuffled_weight, - activation_type=ActivationType.Swiglu, # Default for BF16 + gated_act_type=GatedActType.SwiGlu, # Default for BF16 ) inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] @@ -1422,7 +1426,6 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, - activation_type: ActivationType = ActivationType.Swiglu, ) -> torch.Tensor: if enable_pdl is None: enable_pdl = device_support_pdl(hidden_states.device) @@ -1457,7 +1460,6 @@ def trtllm_fp8_per_tensor_scale_moe_op( intermediate_size=intermediate_size, weight_layout=WeightLayout.MajorK, use_shuffled_weight=True, - activation_type=activation_type, ) inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] @@ -1482,7 +1484,6 @@ def trtllm_fp8_per_tensor_scale_moe_op( use_routing_scales_on_input=use_routing_scales_on_input, routing_method_type=routing_method_type, enable_pdl=enable_pdl, - activation_type=activation_type.value, ) # Call the C++ function result = moe_op.trtllm_fp8_per_tensor_scale_moe( @@ -1507,7 +1508,6 @@ def trtllm_fp8_per_tensor_scale_moe_op( routing_method_type, enable_pdl, [-1, -1] if tactic == -1 else tactic, - activation_type.value, ) return result @@ -1533,7 +1533,6 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( use_routing_scales_on_input: bool, routing_method_type: int = 0, enable_pdl: Optional[bool] = None, - activation_type: int = ActivationType.Swiglu.value, ): seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -1752,7 +1751,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type: int, do_finalize: bool, enable_pdl: Optional[bool] = None, - activation_type: int = ActivationType.Swiglu.value, + gated_act_type: int = 0, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -1812,7 +1811,7 @@ def trtllm_fp4_block_scale_moe_op( use_deepseek_fp8=False, hidden_size=hidden_size, intermediate_size=intermediate_size, - activation_type=activation_type, + gated_act_type=gated_act_type, weight_layout=WeightLayout.MajorK, use_shuffled_weight=True, ) @@ -1859,7 +1858,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type=routing_method_type, enable_pdl=enable_pdl, do_finalize=do_finalize, - activation_type=activation_type, + gated_act_type=gated_act_type, ) # Call the C++ function for block scale MoE @@ -1893,7 +1892,7 @@ def trtllm_fp4_block_scale_moe_op( routing_method_type, do_finalize, enable_pdl, - activation_type, + gated_act_type, output, [-1, -1] if tactic == -1 else tactic, ) @@ -1938,7 +1937,7 @@ def _fake_trtllm_fp4_block_scale_moe( routing_method_type: int, do_finalize: bool, enable_pdl: bool, - activation_type: int, + gated_act_type: int, output: Optional[torch.Tensor], tune_max_num_tokens: int, ): @@ -2010,7 +2009,7 @@ def trtllm_mxint4_block_scale_moe_op( use_deepseek_fp8=False, hidden_size=hidden_size, intermediate_size=intermediate_size, - activation_type=ActivationType.Swiglu, + gated_act_type=GatedActType.SwiGlu, weight_layout=WeightLayout.BlockMajorK, use_shuffled_weight=True, ) @@ -2217,7 +2216,6 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, - activation_type: int = ActivationType.Swiglu.value, ) -> torch.Tensor: """FP8 per tensor scale MoE operation. @@ -2242,15 +2240,6 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type: Type of routing method to use (default: 0) enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) - activation_type (int): Type of activation function (default: 3 - Swiglu) - - 0: Gelu - - 1: Relu - - 2: Silu - - 3: Swiglu - - 4: Geglu - - 5: SwigluBias - - 6: Relu2 - - 7: Identity Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] @@ -2276,7 +2265,6 @@ def trtllm_fp8_per_tensor_scale_moe( routing_method_type, enable_pdl, tune_max_num_tokens, - activation_type, ) @@ -2478,7 +2466,7 @@ def trtllm_fp4_block_scale_moe( routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, - activation_type: int = ActivationType.Swiglu.value, + gated_act_type: int = 0, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -2533,15 +2521,9 @@ def trtllm_fp4_block_scale_moe( - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) enable_pdl (Optional[bool]): Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. - activation_type (int): Type of activation function (default: 3 - Swiglu) - - 0: Gelu - - 1: Relu - - 2: Silu - - 3: Swiglu - - 4: Geglu - - 5: SwigluBias - - 6: Relu2 - - 7: Identity + gated_act_type (int): Type of gated activation function (default: 0) + - 0: SwiGlu + - 1: GeGlu tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. @@ -2579,7 +2561,7 @@ def trtllm_fp4_block_scale_moe( routing_method_type, do_finalize, enable_pdl, - activation_type, + gated_act_type, output, tune_max_num_tokens, ) @@ -2614,7 +2596,7 @@ def trtllm_fp4_block_scale_routed_moe( routing_method_type: int = 0, do_finalize: bool = True, enable_pdl: Optional[bool] = None, - activation_type: int = ActivationType.Swiglu.value, + gated_act_type: int = 0, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -2670,15 +2652,9 @@ def trtllm_fp4_block_scale_routed_moe( - 3: Llama4 (Top1 -> Sigmoid) - 4: RenormalizeNaive (Softmax -> TopK -> Renormalize) do_finalize (bool): Whether to finalize the output (default: False) - activation_type (int): Type of activation function (default: 3 - Swiglu) - - 0: Gelu - - 1: Relu - - 2: Silu - - 3: Swiglu - - 4: Geglu - - 5: SwigluBias - - 6: Relu2 - - 7: Identity + gated_act_type (int): Type of gated activation function (default: 0) + - 0: SwiGlu + - 1: GeGlu tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. @@ -2717,7 +2693,7 @@ def trtllm_fp4_block_scale_routed_moe( routing_method_type, do_finalize, enable_pdl, - activation_type, + gated_act_type, output, tune_max_num_tokens, ) diff --git a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h index 54cd824c0e..970f1ae494 100644 --- a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h +++ b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h @@ -47,27 +47,11 @@ enum class ActType { GeGlu, }; -// Type of the element-wise activation to apply after the Gemm -enum class EltwiseActType { - None = 0, - // Gelu is defined as the following operation: - // act = x0 * phi(x0) - // where x0 is the output of the Gemm - // phi is the CDF of standard normal distribution approximated by - // phi(x) = 0.5 * (1 + tanh(0.7978845608028654 * (x + 0.044715 * x * x * x))) - Gelu, - // Relu2 (also known as squared Relu) is defined as the following operation: - // act = relu(x0) ^ 2 - // where x0 is the output of the Gemm. - Relu2, -}; - struct TrtllmGenBatchedGemmRunnerOptions { batchedGemm::trtllm::gen::Dtype dtypeA; batchedGemm::trtllm::gen::Dtype dtypeB; batchedGemm::trtllm::gen::Dtype dtypeC; ActType actType{ActType::SwiGlu}; - EltwiseActType eltwiseActType{EltwiseActType::None}; bool deepSeekFp8{false}; bool fusedAct{false}; bool routeAct{false}; @@ -75,7 +59,7 @@ struct TrtllmGenBatchedGemmRunnerOptions { bool transposeMmaOutput{false}; int32_t tileSize{8}; int32_t epilogueTileM{128}; - bool useShuffledMatrix{false}; + bool useShuffledMatrixA{false}; batchedGemm::gemm::MatrixLayout weightLayout{batchedGemm::gemm::MatrixLayout::MajorK}; }; diff --git a/include/flashinfer/trtllm/fused_moe/DevKernel.h b/include/flashinfer/trtllm/fused_moe/DevKernel.h index 560063c023..23abb87a7b 100644 --- a/include/flashinfer/trtllm/fused_moe/DevKernel.h +++ b/include/flashinfer/trtllm/fused_moe/DevKernel.h @@ -169,65 +169,56 @@ namespace moe::dev { FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, extraFlag, numExperts, \ - numTopExperts) \ - if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, float, float, numExperts, numTopExperts, extraFlag), kernel, \ - numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Fp32) { \ - LAUNCH_TILEN( \ - data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ - data.mDtypeExpW == tg::Dtype::Bfloat16) { \ - LAUNCH_TILEN(data, coopLaunch, \ - LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, \ - numTopExperts, extraFlag), \ - kernel, numBlocks, numThreads, smemSize, stream); \ - } else { \ - FLASHINFER_WARN("Unsupported dtypeExpW"); \ +#define LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, extraFlag, numExperts) \ + if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, float, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, float, __nv_bfloat16, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(float, __nv_bfloat16, float, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Fp32 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(float, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, LAUNCH_ESC(__nv_bfloat16, float, float, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Fp32 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, float, __nv_bfloat16, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Fp32) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, float, numExperts, extraFlag), kernel, \ + numBlocks, numThreads, smemSize, stream); \ + } else if (data.mDtypeScore == tg::Dtype::Bfloat16 && data.mDtypeBias == tg::Dtype::Bfloat16 && \ + data.mDtypeExpW == tg::Dtype::Bfloat16) { \ + LAUNCH_TILEN(data, coopLaunch, \ + LAUNCH_ESC(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16, numExperts, extraFlag), \ + kernel, numBlocks, numThreads, smemSize, stream); \ + } else { \ + FLASHINFER_WARN("Unsupported dtypeExpW"); \ } -#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ - stream, extraFlag, numExperts, numTopExperts) \ - if (extraFlag) { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, true, numExperts, numTopExperts); \ - } else { \ - LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ - smemSize, stream, false, numExperts, numTopExperts); \ +#define LAUNCH_ROUTING_DEEPSEEK_IMPL(data, coopLaunch, kernel, numBlocks, numThreads, smemSize, \ + stream, extraFlag, numExperts) \ + if (extraFlag) { \ + LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, true, numExperts); \ + } else { \ + LAUNCH_ROUTING_DEEPSEEK_WITH_EXTRA_FLAG(data, coopLaunch, kernel, numBlocks, numThreads, \ + smemSize, stream, false, numExperts); \ } //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h index 709fb57c0f..cae6729368 100644 --- a/include/flashinfer/trtllm/fused_moe/RoutingKernel.h +++ b/include/flashinfer/trtllm/fused_moe/RoutingKernel.h @@ -176,15 +176,14 @@ struct Data : public DataBase { bool mUseRoutingSoftmax; }; -template +template struct KernelParams : public KernelParamsBase { using InputT = InputT_; using BiasT = BiasT_; using OutputT = OutputT_; static constexpr bool UseGroups = UseGroups_; - static constexpr int MaxNumTopExperts = MaxNumTopExperts_; PackedScoreIdx* mPtrTopKPacked = nullptr; diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 46617e5dbd..3941a23249 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -136,48 +136,25 @@ class Runner { } // namespace Routing namespace MoE { -// The type of activation function +// The type of gated activation function // Please keep this in sync with the counterpart defined in flashinfer/flashinfer/fused_moe/core.py -enum class ActivationType : int64_t { - Gelu = 0, - Relu = 1, - Silu = 2, - Swiglu = 3, - Geglu = 4, - SwigluBias = 5, - Relu2 = 6, - Identity = 7, - InvalidType = 8, // Must be last +enum class GatedActType : int64_t { + // SwiGlu + SwiGlu = 0, + // GeGlu + GeGlu = 1, }; -inline std::string serializeActivationType(ActivationType activationType) { - switch (activationType) { - case ActivationType::Gelu: - return "Gelu"; - case ActivationType::Relu: - return "Relu"; - case ActivationType::Silu: - return "Silu"; - case ActivationType::Swiglu: - return "Swiglu"; - case ActivationType::Geglu: - return "Geglu"; - case ActivationType::SwigluBias: - return "SwigluBias"; - case ActivationType::Relu2: - return "Relu2"; - case ActivationType::Identity: - return "Identity"; +inline std::string serializeGatedActType(GatedActType gatedActType) { + switch (gatedActType) { + case GatedActType::SwiGlu: + return "SwiGlu"; + case GatedActType::GeGlu: + return "GeGlu"; default: - return "InvalidActivationType"; // TODO throw error + return "InvalidGatedActType"; // TODO throw error }; } - -inline bool isGatedActivation(ActivationType activationType) { - return activationType == ActivationType::Swiglu || activationType == ActivationType::Geglu || - activationType == ActivationType::SwigluBias; -} - } // namespace MoE namespace PermuteGemm1 { @@ -185,7 +162,7 @@ class Runner { public: explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8, - int tileTokensDim, MoE::ActivationType activationType, bool useShuffledMatrix, + int tileTokensDim, MoE::GatedActType gatedActType, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weight_layout); size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -216,7 +193,6 @@ class Runner { batchedGemm::trtllm::gen::Dtype mDtypeWeights; int32_t mTileTokensDim; tensorrt_llm::kernels::TrtllmGenBatchedGemmRunner mRunner; - tensorrt_llm::kernels::trtllmgen_moe::MoE::ActivationType mActType; }; } // namespace PermuteGemm1 @@ -226,7 +202,7 @@ class Runner { explicit Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, batchedGemm::trtllm::gen::Dtype outputDtype, bool useDeepSeekFp8, - int tileTokensDim, bool useShuffledMatrix, + int tileTokensDim, bool useShuffledMatrixA, batchedGemm::gemm::MatrixLayout weight_layout); size_t getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -283,8 +259,6 @@ struct MoERunnerArgs { float* gemm1_clamp_limit = nullptr; float* gemm2_bias = nullptr; - ActivationType activation_type = ActivationType::Swiglu; - int32_t num_tokens{0}; int32_t num_experts{0}; // Hidden dimension input of MoE block. It might be padded. @@ -382,10 +356,10 @@ class Runner { // FIXME: tileTokensDim is hardcoded for now Runner(batchedGemm::trtllm::gen::Dtype dtypeAct, batchedGemm::trtllm::gen::Dtype dtypeWeights, bool useDeepSeekFp8, int tileTokensDim = 8, - ActivationType activationType = ActivationType::Swiglu, bool useShuffledMatrix = false, + GatedActType gatedActType = GatedActType::SwiGlu, bool useShuffledMatrixA = false, batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK); Runner(batchedGemm::trtllm::gen::Dtype dtypeElt, bool useDeepSeekFp8, int tileTokensDim = 8, - bool useShuffledMatrix = false, + bool useShuffledMatrixA = false, batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK); void run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py index cd44f2faf2..711e05f234 100644 --- a/tests/moe/test_dpsk_fused_moe_fp8.py +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -8,7 +8,7 @@ trtllm_fp8_block_scale_moe, ) from .utils import skip_checks, QuantMode -from flashinfer import ActivationType +from flashinfer import GatedActType def dequant_fp8_block_scaled( @@ -616,7 +616,7 @@ def __init__(self): moe_impl=moe_impl, routing_config=routing_config, weight_processing=weight_processing, - activation_type=ActivationType.Swiglu, + gated_act_type=GatedActType.SwiGlu, num_tokens=seq_len, hidden_size=7168, # DeepSeek-V3 hidden size intermediate_size=intermediate_size, diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 4d2d56380e..89cbf84d4e 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -22,8 +22,8 @@ from torch.nn import functional as F from flashinfer import ( - ActivationType, RoutingMethodType, + GatedActType, e2m1_and_ufp8sf_scale_to_float, fp4_quantize, mxfp8_dequantize_host, @@ -46,7 +46,7 @@ get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, ) -from .utils import is_gated_activation, skip_checks, QuantMode +from .utils import skip_checks, QuantMode # Max num tokens to tune for trtllm-gen fused moe @@ -209,7 +209,7 @@ def _run_moe_computation(self, runtime_args): local_num_experts=self.config["num_experts"], routed_scaling_factor=self.config["routed_scaling"], routing_method_type=self.config["routing_method_type"], - activation_type=self.config["activation_type"], + gated_act_type=self.config["gated_act_type"], do_finalize=True, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, ) @@ -227,12 +227,6 @@ class Moe(ABC): def __init__(self): self.name = self.__class__.__name__ - @property - @abstractmethod - def quant_mode(self) -> QuantMode: - """Get the quantization mode of this MoE implementation.""" - pass - @abstractmethod def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize static weights and compute global scale factors (done offline).""" @@ -311,17 +305,13 @@ class FP4Moe(Moe): def __init__(self, quant_mode: QuantMode): super().__init__() - self._quant_mode = quant_mode + self.quant_mode = quant_mode self.is_mxfp4 = ( quant_mode == QuantMode.FP4_MXFP4_MXFP8 or quant_mode == QuantMode.FP4_MXFP4_Bf16 ) self.sf_vec_size = 32 if self.is_mxfp4 else 16 - @property - def quant_mode(self) -> QuantMode: - return self._quant_mode - def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP4 format and compute global scale factors.""" num_experts = gemm1_weights.shape[0] @@ -418,16 +408,13 @@ def prepare_static_weights_for_kernel( ) # Convert quantized weights to proper formats - intermediate_size_factor = 2 if is_gated_activation(args.activation_type) else 1 gemm1_weights_fp4 = args.gemm1_weights.view(torch.float8_e4m3fn).reshape( - num_experts, intermediate_size_factor * intermediate_size, hidden_size // 2 + num_experts, 2 * intermediate_size, hidden_size // 2 ) # packed fp4 gemm1_scales_linear_fp4 = gemm1_scales_linear_fp4_bytes.view( torch.float8_e4m3fn ).reshape( - num_experts, - intermediate_size_factor * intermediate_size, - hidden_size // self.sf_vec_size, + num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size ) # fp8 scaling factors gemm2_weights_fp4 = args.gemm2_weights.view(torch.float8_e4m3fn).reshape( @@ -453,7 +440,6 @@ def prepare_static_weights_for_kernel( self._cache_permute_indices, gemm1_weights_fp4[i].view(torch.uint8), epilogue_tile_m, - is_gated_act_gemm=is_gated_activation(args.activation_type), ) gemm1_weights_fp4_shuffled.append( gemm1_weights_fp4[i] @@ -466,7 +452,6 @@ def prepare_static_weights_for_kernel( gemm1_scales_linear_fp4[i].view(torch.uint8), epilogue_tile_m, num_elts_per_sf=16, - is_gated_act_gemm=is_gated_activation(args.activation_type), ) gemm1_scales_fp4_shuffled.append( block_scale_interleave( @@ -511,9 +496,7 @@ def prepare_static_weights_for_kernel( torch.stack(gemm1_scales_fp4_shuffled) .view(torch.float8_e4m3fn) .reshape( - num_experts, - intermediate_size_factor * intermediate_size, - hidden_size // self.sf_vec_size, + num_experts, 2 * intermediate_size, hidden_size // self.sf_vec_size ) ) @@ -525,16 +508,11 @@ def prepare_static_weights_for_kernel( ) # Calculate scaling factors that depend on weights - if is_gated_activation(args.activation_type): - scale_c_fc1 = ( - args_dequant.c_global_sf - * (1.0 / args.gemm1_scales_global) - * (1.0 / args.hidden_states_scale_global) - ) - else: - scale_c_fc1 = torch.full_like( - args.gemm1_scales_global, args_dequant.c_global_sf - ) + scale_c_fc1 = ( + args_dequant.c_global_sf + * (1.0 / args.gemm1_scales_global) + * (1.0 / args.hidden_states_scale_global) + ) scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( 1.0 / args.hidden_states_scale_global ) @@ -565,7 +543,7 @@ def call_moe( top_k_groups = kwargs["top_k_groups"] intermediate_size = kwargs["intermediate_size"] routed_scaling = kwargs["routed_scaling"] - activation_type = kwargs["activation_type"] + gated_act_type = kwargs["gated_act_type"] routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) @@ -578,7 +556,7 @@ def call_moe( "top_k_groups": top_k_groups, "intermediate_size": intermediate_size, "routed_scaling": routed_scaling, - "activation_type": activation_type, + "gated_act_type": gated_act_type, "routing_method_type": routing_method_type, "enable_autotune": enable_autotune, } @@ -632,10 +610,6 @@ def mxint4_quantize( class MxInt4BlockScaleMoe(Moe): """MxInt4 MoE implementation with block scaling (DeepSeek style).""" - @property - def quant_mode(self) -> QuantMode: - return QuantMode.MXINT4_BF16_BF16 - def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to MxInt4 with block scaling.""" num_experts = gemm1_weights.shape[0] @@ -830,10 +804,6 @@ def get_tolerances(self): class FP8BlockScaleMoe(Moe): """FP8 MoE implementation with block scaling (DeepSeek style).""" - @property - def quant_mode(self) -> QuantMode: - return QuantMode.FP8_BLOCK_SCALE - def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 with block scaling.""" num_experts = gemm1_weights.shape[0] @@ -1055,10 +1025,6 @@ def get_tolerances(self): class FP8PerTensorMoe(Moe): """FP8 MoE implementation with per-tensor scaling (Llama4 style).""" - @property - def quant_mode(self) -> QuantMode: - return QuantMode.FP8_PER_TENSOR - def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 per-tensor and compute global scale factors.""" # Compute global scale factor for hidden states (offline calibration) @@ -1114,20 +1080,14 @@ def prepare_static_weights_for_kernel( # Reorder rows of W1 for fused gated activation gemm1_weights_fp8_interleaved = [] for i in range(num_experts): - if is_gated_activation(args.activation_type): - weights = reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) - else: - weights = args.gemm1_weights[i].clone() - gemm1_weights_fp8_interleaved.append(weights) + gemm1_weights_fp8_interleaved.append( + reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) + ) # Stack weights and scales for all experts gemm1_weights_fp8_interleaved = torch.stack( gemm1_weights_fp8_interleaved - ).reshape( - num_experts, - (2 if is_gated_activation(args.activation_type) else 1) * intermediate_size, - hidden_size, - ) + ).reshape(num_experts, 2 * intermediate_size, hidden_size) # Shuffle weights and scaling factors for transposed mma output gemm1_weights_fp8_shuffled = [] @@ -1154,16 +1114,11 @@ def prepare_static_weights_for_kernel( ) # Calculate scaling factors that depend on weights - if is_gated_activation(args.activation_type): - scale_c_fc1 = ( - args_dequant.c_global_sf - * (1.0 / args.gemm1_scales_global) - * (1.0 / args.hidden_states_scale_global) - ) - else: - scale_c_fc1 = torch.full_like( - args.gemm1_scales_global, args_dequant.c_global_sf - ) + scale_c_fc1 = ( + args_dequant.c_global_sf + * (1.0 / args.gemm1_scales_global) + * (1.0 / args.hidden_states_scale_global) + ) scale_gate_fc1 = (1.0 / args.gemm1_scales_global) * ( 1.0 / args.hidden_states_scale_global ) @@ -1193,7 +1148,6 @@ def call_moe( routed_scaling = kwargs["routed_scaling"] routing_method_type = kwargs["routing_method_type"] enable_autotune = kwargs.get("enable_autotune", True) - activation_type = kwargs["activation_type"] # Quantize to FP8 per-tensor using pre-computed global scale factor hidden_states_fp8, _ = quant_fp8_per_tensor( @@ -1227,7 +1181,6 @@ def call_moe( == RoutingMethodType.Llama4, # Use_routing_scales_on_input routing_method_type, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, - activation_type=activation_type, ) return output.to(torch.float) @@ -1249,10 +1202,6 @@ def get_tolerances(self): class BF16Moe(Moe): """BF16 MoE implementation.""" - @property - def quant_mode(self) -> QuantMode: - return QuantMode.BF16 - def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """No scaling for weights.""" return { @@ -1434,7 +1383,7 @@ def __init__( gemm2_scales_global, permute_info, use_routing_scales_on_input, - activation_type, + gated_act_type, ): self.num_tokens = num_tokens self.num_experts = num_experts @@ -1454,7 +1403,7 @@ def __init__( self.gemm2_scales_global = gemm2_scales_global self.permute_info = permute_info self.use_routing_scales_on_input = use_routing_scales_on_input - self.activation_type = activation_type + self.gated_act_type = gated_act_type class moe_args_dequant: @@ -1474,7 +1423,7 @@ def __init__( gemm2_weights, permute_info, use_routing_scales_on_input, - activation_type, + gated_act_type, hidden_states_scale=None, ): self.num_tokens = num_tokens @@ -1489,7 +1438,7 @@ def __init__( self.gemm2_weights = gemm2_weights self.permute_info = permute_info self.use_routing_scales_on_input = use_routing_scales_on_input - self.activation_type = activation_type + self.gated_act_type = gated_act_type self.hidden_states_scale = hidden_states_scale @@ -1913,11 +1862,7 @@ def run_moe_dequant(args, quant_mode: QuantMode): # Gemm1 gemm1_output = torch.full( - ( - total_num_padded_tokens, - (2 if is_gated_activation(args.activation_type) else 1) - * args.intermediate_size, - ), + (total_num_padded_tokens, 2 * args.intermediate_size), float("nan"), device="cuda", ).to(torch.float) @@ -1952,13 +1897,12 @@ def run_moe_dequant(args, quant_mode: QuantMode): (total_num_padded_tokens, args.intermediate_size), float("nan"), device="cuda" ).to(torch.float) - activation_type = args.activation_type - activation_type_to_func = { - ActivationType.Swiglu: F.silu, - ActivationType.Geglu: F.gelu, - ActivationType.Relu2: lambda x: F.relu(x) ** 2, + gated_act_type = args.gated_act_type + gated_act_type_to_func = { + 0: F.silu, + 1: F.gelu, } - activation_func = activation_type_to_func[activation_type] + gated_act_func = gated_act_type_to_func[gated_act_type] i = 0 for expert_idx in range(args.num_experts): @@ -1966,13 +1910,9 @@ def run_moe_dequant(args, quant_mode: QuantMode): if my_num_tokens == 0: continue my_a = gemm1_output[i : i + my_num_tokens] - if is_gated_activation(args.activation_type): - my_x1 = my_a[:, : args.intermediate_size] - my_x2 = my_a[:, args.intermediate_size :] - activation_output[i : i + my_num_tokens] = activation_func(my_x2) * my_x1 - else: - my_x1 = my_a[:, : args.intermediate_size] - activation_output[i : i + my_num_tokens] = activation_func(my_x1) + my_x1 = my_a[:, : args.intermediate_size] + my_x2 = my_a[:, args.intermediate_size :] + activation_output[i : i + my_num_tokens] = gated_act_func(my_x2) * my_x1 i += my_num_tokens i = (i + args.padding - 1) // args.padding * args.padding @@ -2099,7 +2039,7 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.activation_type, + args.gated_act_type, ) return run_moe_dequant(args_dequant, quant_mode), args_dequant @@ -2164,7 +2104,7 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.activation_type.value, + GatedActType.SwiGlu.value, # gated_act_type ) return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant @@ -2201,7 +2141,7 @@ def run_moe_reference_per_tensor_scale_fp8(args): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.activation_type.value, + GatedActType.SwiGlu.value, # gated_act_type ) return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant @@ -2232,7 +2172,7 @@ def run_moe_reference_bf16(args): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.activation_type.value, + GatedActType.SwiGlu.value, # gated_act_type ) return run_moe_dequant(args_dequant, QuantMode.BF16), args_dequant @@ -2283,7 +2223,7 @@ def dequantize(weights, scales): gemm2_weights_dequant, args.permute_info, args.use_routing_scales_on_input, - args.activation_type, + args.gated_act_type, ) return run_moe_dequant(args_dequant, QuantMode.MXINT4_BF16_BF16), args_dequant @@ -2317,7 +2257,7 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs): "routed_scaling": kwargs["routed_scaling"], "routing_method_type": kwargs["routing_method_type"], "do_finalize": True, - "activation_type": args.activation_type, + "gated_act_type": args.gated_act_type, "hidden_states_scale": args.hidden_states_scale, "hidden_states_quant": kwargs["hidden_states_quant"], "enable_autotune": kwargs.get("enable_autotune", True), @@ -2345,7 +2285,7 @@ def run_moe_test( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, zero_hidden_states=False, ): @@ -2354,7 +2294,7 @@ def run_moe_test( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, num_tokens, hidden_size, intermediate_size, @@ -2379,7 +2319,7 @@ def run_moe_test( # Validation checks assert top_k <= num_experts - assert top_k <= 22 + assert top_k <= 10 if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0): assert top_k_groups <= 4 assert num_experts > n_groups @@ -2407,11 +2347,7 @@ def run_moe_test( (num_tokens, hidden_size), device="cuda", dtype=torch.bfloat16 ) gemm1_weights = torch.randn( - ( - num_experts, - (2 if is_gated_activation(activation_type) else 1) * intermediate_size, - hidden_size, - ), + (num_experts, 2 * intermediate_size, hidden_size), device="cuda", dtype=torch.bfloat16, ) @@ -2496,7 +2432,7 @@ def run_moe_test( quant_data["gemm2_scales_global"], permute_info, use_routing_scales_on_input, - activation_type, + gated_act_type, ) # Compute reference output @@ -2665,10 +2601,10 @@ def run_moe_test( ], ) @pytest.mark.parametrize( - "activation_type", + "gated_act_type", [ - pytest.param(ActivationType.Swiglu, id="Swiglu"), - pytest.param(ActivationType.Geglu, id="Geglu"), + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), ], ) def test_renormalize_routing( @@ -2678,7 +2614,7 @@ def test_renormalize_routing( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, zero_hidden_states, ): @@ -2690,7 +2626,7 @@ def test_renormalize_routing( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, zero_hidden_states=zero_hidden_states, ) @@ -2699,11 +2635,10 @@ def test_renormalize_routing( # Test: DeepSeekV3 routing @pytest.mark.parametrize("num_tokens", [8, 768, 3072]) @pytest.mark.parametrize("hidden_size", [1024]) -@pytest.mark.parametrize("intermediate_size", [2688, 2048, 1024, 768, 512, 384]) +@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384]) @pytest.mark.parametrize( "moe_impl", [ - pytest.param(FP8PerTensorMoe(), id="FP8_PerTensor"), pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), @@ -2715,22 +2650,6 @@ def test_renormalize_routing( @pytest.mark.parametrize( "routing_config", [ - pytest.param( - { - "num_experts": 512, - "top_k": 22, - "padding": 8, - "n_groups": 1, - "top_k_groups": 1, - "routed_scaling": 2.5, - "has_routing_bias": True, - "routing_method_type": RoutingMethodType.DeepSeekV3, - "compatible_moe_impls": [FP8PerTensorMoe, FP4Moe, BF16Moe], - "compatible_intermediate_size": [1024, 2688], - "enable_autotune": True, - }, - id="nemotron_3", - ), pytest.param( { "num_experts": 384, @@ -2836,11 +2755,10 @@ def test_renormalize_routing( ], ) @pytest.mark.parametrize( - "activation_type", + "gated_act_type", [ - pytest.param(ActivationType.Swiglu, id="Swiglu"), - pytest.param(ActivationType.Geglu, id="Geglu"), - pytest.param(ActivationType.Relu2, id="Relu2"), + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), ], ) def test_deepseekv3_routing( @@ -2850,7 +2768,7 @@ def test_deepseekv3_routing( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, ): """Test DeepSeekV3 routing configurations.""" @@ -2861,7 +2779,7 @@ def test_deepseekv3_routing( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, ) @@ -2912,10 +2830,10 @@ def test_deepseekv3_routing( ], ) @pytest.mark.parametrize( - "activation_type", + "gated_act_type", [ - pytest.param(ActivationType.Swiglu, id="Swiglu"), - pytest.param(ActivationType.Geglu, id="Geglu"), + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), + pytest.param(GatedActType.GeGlu, id="GeGlu"), ], ) def test_topk_routing( @@ -2925,7 +2843,7 @@ def test_topk_routing( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, ): """Test TopK routing configuration.""" @@ -2936,7 +2854,7 @@ def test_topk_routing( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, ) @@ -2949,7 +2867,6 @@ def test_topk_routing( "moe_impl", [ pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), - pytest.param(FP4Moe(QuantMode.FP4_NVFP4_NVFP4), id="FP4"), ], ) @pytest.mark.parametrize( @@ -2965,7 +2882,7 @@ def test_topk_routing( "routed_scaling": 2.5, "has_routing_bias": True, "routing_method_type": RoutingMethodType.Llama4, - "compatible_moe_impls": [FP8PerTensorMoe, FP4Moe], + "compatible_moe_impls": [FP8PerTensorMoe], "compatible_intermediate_size": [1024, 2048], "enable_autotune": True, }, @@ -2987,10 +2904,9 @@ def test_topk_routing( ], ) @pytest.mark.parametrize( - "activation_type", + "gated_act_type", [ - pytest.param(ActivationType.Swiglu, id="Swiglu"), - pytest.param(ActivationType.Relu2, id="Relu2"), + pytest.param(GatedActType.SwiGlu, id="SwiGlu"), ], ) def test_llama4_routing( @@ -3000,7 +2916,7 @@ def test_llama4_routing( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, ): """Test Llama4 routing configuration with FP8 per-tensor.""" @@ -3011,6 +2927,6 @@ def test_llama4_routing( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, cache_permute_indices, ) diff --git a/tests/moe/test_trtllm_gen_routed_fused_moe.py b/tests/moe/test_trtllm_gen_routed_fused_moe.py index a5272ceb36..7a47444081 100644 --- a/tests/moe/test_trtllm_gen_routed_fused_moe.py +++ b/tests/moe/test_trtllm_gen_routed_fused_moe.py @@ -20,7 +20,7 @@ from flashinfer import ( RoutingMethodType, - ActivationType, + GatedActType, fp4_quantize, mxfp8_quantize, ) @@ -185,7 +185,7 @@ def test_trtllm_gen_routed_fused_moe( routing_method_type.value, True, # do_finalize enable_pdl, - ActivationType.Swiglu.value, # act_type + GatedActType.SwiGlu.value, # gated_act_type None, )[0].to(torch.float) @@ -238,7 +238,7 @@ def test_trtllm_gen_routed_fused_moe( routing_method_type.value, True, # do_finalize enable_pdl, - ActivationType.Swiglu.value, # act_type + GatedActType.SwiGlu.value, # gated_act_type None, )[0].to(torch.float) diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 8ff5cf82a2..19c01d5175 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -17,7 +17,7 @@ import pytest import torch from enum import IntEnum -from flashinfer import ActivationType, RoutingMethodType +from flashinfer import GatedActType, RoutingMethodType from flashinfer.utils import get_compute_capability @@ -33,25 +33,11 @@ class QuantMode(IntEnum): MXINT4_BF16_BF16 = 7 -NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES = [ - QuantMode.FP4_NVFP4_NVFP4, - QuantMode.FP8_PER_TENSOR, -] - - -def is_gated_activation(activation_type: ActivationType) -> bool: - return activation_type in [ - ActivationType.Swiglu, - ActivationType.Geglu, - ActivationType.SwigluBias, - ] - - def skip_checks( moe_impl, routing_config, weight_processing, - activation_type, + gated_act_type, num_tokens, hidden_size, intermediate_size, @@ -71,32 +57,24 @@ def skip_checks( pytest.skip("Skipping zero hidden states tests for non-FP8 Block Scale MoE.") # Skip incompatible combinations - if activation_type == ActivationType.Geglu and ( + if gated_act_type == GatedActType.GeGlu and ( not is_fp4_moe or moe_impl.quant_mode != QuantMode.FP4_NVFP4_NVFP4 or routing_config["routing_method_type"] != RoutingMethodType.TopK or num_tokens > 128 ): pytest.skip( - f"Incompatible: {moe_impl.name} + {activation_type} + {routing_config['routing_method_type']} + {num_tokens}" + f"Incompatible: {moe_impl.name} + {gated_act_type} + {routing_config['routing_method_type']} + {num_tokens}" ) - elif activation_type == ActivationType.Swiglu and ( + elif gated_act_type == GatedActType.SwiGlu and ( hidden_size > 1024 or intermediate_size > 1024 ): pytest.skip( - f"Skip for testing speed: {activation_type} + {hidden_size} + {intermediate_size}" - ) - - if ( - not is_gated_activation(activation_type) - and moe_impl.quant_mode not in NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES - ): - pytest.skip( - f"Incompatible: {moe_impl.name} + {activation_type=} + quant_mode={moe_impl.quant_mode}: non-gated activations only supported with these quant modes: {NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES}" + f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}" ) # Skip large intermediate sizes for configurations with many experts - if routing_config["num_experts"] > 512 and intermediate_size > 512: + if routing_config["num_experts"] >= 512 and intermediate_size > 512: pytest.skip( f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts" ) @@ -114,7 +92,7 @@ def skip_checks( f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)" ) - if moe_impl.quant_mode == QuantMode.MXINT4_BF16_BF16 and ( + if type(moe_impl).__name__ == "MxInt4BlockScaleMoe" and ( intermediate_size % 256 != 0 or hidden_size % 256 != 0 ): pytest.skip(