diff --git a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py index f9314c98fd..e2fc6017d5 100644 --- a/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py +++ b/benchmarks/bench_trtllm_gen_fused_moe_autotuner.py @@ -10,6 +10,7 @@ mxfp8_quantize, ) from flashinfer.fused_moe import ( + Fp8QuantizationType, trtllm_fp4_block_scale_moe, trtllm_mxint4_block_scale_moe, trtllm_fp8_per_tensor_scale_moe, @@ -53,7 +54,7 @@ def mxint4_quantize( def bench_trtllm_gen_fused_moe_autotuner_fp8( tune_max_num_tokens: Optional[int], - quant_mode: Literal["Fp8-Per-Tensor", "Fp8-Block"], + quant_mode: Literal["Fp8-Per-Tensor", "Fp8-Block", "MxFP8xMxFP8"], num_tokens: int, num_experts: int, hidden_size: int, @@ -79,29 +80,54 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( torch.bfloat16 ) - is_block_scale = quant_mode == "Fp8-Block" - if not is_block_scale: + is_block_scale = quant_mode != "Fp8-Per-Tensor" + if quant_mode == "Fp8-Per-Tensor": hidden_states, hidden_states_scale = fp8_quantize(hidden_states) w13, w13_scale = fp8_quantize(w13) w2, w2_scale = fp8_quantize(w2) else: - # block scale quantization is too slow, so we use per-tensor quantization for now - hidden_states, hidden_states_scale = fp8_quantize(hidden_states) - w13, w13_scale = fp8_quantize(w13) - w2, w2_scale = fp8_quantize(w2) - hidden_states_scale = torch.full( - (hidden_size // 128, num_tokens), hidden_states_scale.item(), device=device - ) - w13_scale = torch.full( - (num_experts, intermediate_size * 2 // 128, hidden_size // 128), - w13_scale.item(), - device=device, - ) - w2_scale = torch.full( - (num_experts, hidden_size // 128, intermediate_size // 128), - w2_scale.item(), - device=device, - ) + scale_vec_size = 128 if quant_mode == "Fp8-Block" else 32 + if quant_mode == "Fp8-Block": + # block scale quantization is too slow, so we use per-tensor quantization for now + hidden_states, hidden_states_scale = fp8_quantize( + hidden_states + ) # scalar 
quantization + w13, w13_scale = fp8_quantize(w13) # scalar quantization + w2, w2_scale = fp8_quantize(w2) # scalar quantization + hidden_states_scale = torch.full( + (hidden_size // scale_vec_size, num_tokens), + hidden_states_scale.item(), + device=device, + ) + w13_scale = torch.full( + ( + num_experts, + intermediate_size * 2 // scale_vec_size, + hidden_size // scale_vec_size, + ), + w13_scale.item(), + device=device, + ) + w2_scale = torch.full( + ( + num_experts, + hidden_size // scale_vec_size, + intermediate_size // scale_vec_size, + ), + w2_scale.item(), + device=device, + ) + else: # MxFP8xMxFP8 + hidden_states, hidden_states_scale = mxfp8_quantize(hidden_states, False) + w13, w13_scale = mxfp8_quantize(w13, True) + w2, w2_scale = mxfp8_quantize(w2, True) + hidden_states_scale = hidden_states_scale.view(torch.uint8).reshape( + num_tokens, -1 + ) + w13_scale = w13_scale.view(torch.uint8).reshape( + num_experts, intermediate_size * 2, -1 + ) + w2_scale = w2_scale.view(torch.uint8).reshape(num_experts, hidden_size, -1) output1_scale_scalar = ( torch.tensor([hidden_states_scale * w13_scale] * num_experts, device=device) @@ -136,12 +162,15 @@ def bench_trtllm_gen_fused_moe_autotuner_fp8( local_num_experts=num_experts, routed_scaling_factor=2.5, routing_method_type=RoutingMethodType.DeepSeekV3.value, - use_shuffled_weight=False, - weight_layout=WeightLayout.MajorK.value, # weight_layout + use_shuffled_weight=quant_mode == "MxFP8xMxFP8", + weight_layout=WeightLayout.MajorK.value, enable_pdl=enable_pdl, tune_max_num_tokens=num_tokens if tune_max_num_tokens is None else tune_max_num_tokens, + fp8_quantization_type=Fp8QuantizationType.DeepSeekFp8 + if quant_mode == "Fp8-Block" + else Fp8QuantizationType.MxFp8, ) else: fn = partial( @@ -468,6 +497,7 @@ def bench(do_autotune): "MxFP4xMxFP8", "MxFP4xBf16", "MxInt4xBf16", + "MxFP8xMxFP8", "Fp8-Per-Tensor", "Fp8-Block", ], @@ -505,7 +535,7 @@ def bench(do_autotune): args = parser.parse_args() fn = ( 
bench_trtllm_gen_fused_moe_autotuner_fp8 - if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block"] + if args.quant_mode in ["Fp8-Per-Tensor", "Fp8-Block", "MxFP8xMxFP8"] else bench_trtllm_gen_fused_moe_autotuner_mxint4 if args.quant_mode == "MxInt4xBf16" else bench_trtllm_gen_fused_moe_autotuner_fp4 diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f3eae5e9e3..b0c43ea751 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -135,7 +135,7 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner( size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes( int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim, int32_t configIndex) const { - BatchedGemmData gemmData; + BatchedGemmData gemmData{}; gemmData.mProblemDimensions.mNumBatches = numBatches; gemmData.mProblemDimensions.mNumTokens = numTokens; gemmData.mProblemDimensions.mBatchM = !mOptions.transposeMmaOutput; @@ -174,11 +174,12 @@ void TrtllmGenBatchedGemmRunner::run( CUstream stream, int device, int32_t configIndex, bool enable_pdl) { auto bmm = BatchedGemmInterface(); - BatchedGemmData gemmData; + BatchedGemmData gemmData{}; auto const configs = bmm.getBatchedGemmConfigs(); auto const& config = configs[configIndex]; + // printf("running config %d: %s\n", configIndex, config.mFunctionName); FLASHINFER_CHECK(numBatches > 0, "Batched GEMM requires numBatches > 0"); if (!mOptions.staticBatch) { @@ -327,7 +328,7 @@ std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); - BatchedGemmData gemmData; + BatchedGemmData gemmData{}; // Dims gemmData.mProblemDimensions.mNumBatches = numBatches; gemmData.mProblemDimensions.mNumTokens = numTokens; @@ -436,7 +437,7 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t auto const bmm = 
BatchedGemmInterface(); auto const configs = bmm.getBatchedGemmConfigs(); - BatchedGemmData gemmData; + BatchedGemmData gemmData{}; // Dims gemmData.mProblemDimensions.mNumBatches = numBatches; gemmData.mProblemDimensions.mNumTokens = numTokens; @@ -451,12 +452,13 @@ bool TrtllmGenBatchedGemmRunner::isValidConfigIndex(int32_t configIndex, int32_t gemmData.mProblemDimensions.mRank = 0; gemmData.mProblemDimensions.mWorldSize = 1; gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim; + gemmData.mProblemDimensions.mValidM = gemmData.mProblemDimensions.mM; + gemmData.mProblemDimensions.mValidN = gemmData.mProblemDimensions.mN; + gemmData.mProblemDimensions.mValidK = gemmData.mProblemDimensions.mK; auto const& config = configs[configIndex]; - // FIXME: temporarily disable split-k as renormalize routing plus expert number 256 failed in - // trtllm-gen ac83afb - return bmm.isValidConfig(config, gemmData) && config.mOptions.mClusterDimZ == 1; + return bmm.isValidConfig(config, gemmData); } } // namespace kernels diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index bfa754c1c4..2aa21c65b3 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -41,6 +41,30 @@ using tensorrt_llm::kernels::trtllmgen_moe::Routing::RoutingMethodType; using tvm::ffi::Array; using tvm::ffi::Optional; +// FP8 quantization flavor passed in by the caller of the FP8 MoE entry points. +// Keep the enumerator order in sync with Fp8QuantizationType in flashinfer/fused_moe/core.py. +enum class Fp8QuantizationType { + NoneFp8, + DeepSeekFp8, + MxFp8, + PerTensorFp8, +}; + +// Returns a printable name for diagnostics; values outside the enum map to "Unknown". +inline std::string fp8QuantizationTypeToString(Fp8QuantizationType quantization_type) { + switch (quantization_type) { + case Fp8QuantizationType::NoneFp8: + return "NoneFp8"; + case Fp8QuantizationType::DeepSeekFp8: + return "DeepSeekFp8"; + case Fp8QuantizationType::MxFp8: + return "MxFp8"; + case Fp8QuantizationType::PerTensorFp8: + return "PerTensorFp8"; + } + return "Unknown"; +} + // Utility function to compute the next power of two inline int32_t nextPowerOfTwo(float value) { int32_t n = 
static_cast(std::ceil(value)); @@ -111,6 +132,8 @@ class FusedMoeLauncher { btg::Dtype::Bfloat16}; // Dtype for expert weights in routing, based on routing bias ActivationType activation_type{ActivationType::Swiglu}; + int64_t intermediate_size_factor{2}; + public: // Constructor that initializes all TensorView members FusedMoeLauncher(const Optional& routing_logits, @@ -134,7 +157,8 @@ class FusedMoeLauncher { weight_layout{batchedGemm::gemm::MatrixLayout::MajorK}, mDtypeAct{btg::Dtype::Bfloat16}, mDtypeWeights{btg::Dtype::Bfloat16}, - activation_type{ActivationType::Swiglu} {} + activation_type{ActivationType::Swiglu}, + intermediate_size_factor{2} {} protected: // Initialize common data necessary for later. @@ -315,6 +339,14 @@ class FusedMoeLauncher { args->top_k, args->hidden_size, args->intermediate_size, args->local_num_experts, args->num_tokens); } + auto valid_cfgs = + moe_runner->getValidConfigIndices(args->top_k, args->hidden_size, args->intermediate_size, + args->local_num_experts, args->num_tokens); + auto valid_it = std::find(valid_cfgs.begin(), valid_cfgs.end(), moe_tactic); + FLASHINFER_CHECK(valid_it != valid_cfgs.end(), "Invalid MoE tactic ", moe_tactic, + " for tile_N=", tile_tokens_dim, ". Number of valid tactics for this tile is ", + valid_cfgs.size(), + ". This often indicates a stale or mismatched autotuner cache entry."); this->moe_tactic = moe_tactic; auto workspace_sizes = moe_runner->getWorkspaceSizeInBytes(*args, moe_tactic); @@ -400,6 +432,7 @@ void FusedMoeLauncher::init_common( << "the value of weight_layout is not recognized"; this->weight_layout = static_cast(weight_layout); this->activation_type = activation_type; + this->intermediate_size_factor = isGatedActivation(activation_type) ? 
2 : 1; } class Bf16MoeLauncher : public FusedMoeLauncher { @@ -707,14 +740,22 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { class Fp8BlockScaleLauncher : public FusedMoeLauncher { public: - static constexpr std::array mSupportedTileNums = {8, 16, 32, 64, 128}; + static constexpr std::array mBaseSupportedTileNums = {8, 16, 32, 64, 128}; + + static std::vector getSupportedTileNums(Fp8QuantizationType quantization_type) { + std::vector tiles(mBaseSupportedTileNums.begin(), mBaseSupportedTileNums.end()); + if (quantization_type == Fp8QuantizationType::MxFp8) { + tiles.push_back(256); + } + return tiles; + } Fp8BlockScaleLauncher(Optional const& routing_logits, Optional const& routing_bias, TensorView const& hidden_states, TensorView const& hidden_states_scale, TensorView const& gemm1_weights, TensorView const& gemm1_weights_scale, TensorView const& gemm2_weights, TensorView const& gemm2_weights_scale, TensorView const& expert_indices, - TensorView const& expert_weights) + TensorView const& expert_weights, Fp8QuantizationType quantization_type) : FusedMoeLauncher(routing_logits, routing_bias, hidden_states, gemm1_weights, Optional(), Optional(), gemm2_weights, Optional()), @@ -722,15 +763,21 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { gemm1_weights_scale(gemm1_weights_scale), gemm2_weights_scale(gemm2_weights_scale), expert_indices(expert_indices), - expert_weights(expert_weights) {} + expert_weights(expert_weights), + quantization_type(quantization_type) {} void init(std::unique_ptr&& args, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout) { constexpr ActivationType activation_type = ActivationType::Swiglu; - mDtypeAct = btg::Dtype::E4m3; - mDtypeWeights = btg::Dtype::E4m3; + if (quantization_type == Fp8QuantizationType::MxFp8) { + mDtypeAct = btg::Dtype::MxE4m3; + mDtypeWeights = btg::Dtype::MxE4m3; + } else { + mDtypeAct = btg::Dtype::E4m3; + mDtypeWeights = btg::Dtype::E4m3; + } auto 
dtype = hidden_states.dtype(); if (dtype == dl_float16) { @@ -817,7 +864,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; } - args->mUseDeepSeekFp8 = true; + args->mUseDeepSeekFp8 = quantization_type == Fp8QuantizationType::DeepSeekFp8; // Check ndim==2 and size>0 because empty placeholder tensors may have non-null data_ptr bool has_precomputed_indices = expert_indices.ndim() == 2 && expert_indices.size(0) > 0; if (has_precomputed_indices) { @@ -848,43 +895,72 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { FusedMoeLauncher::check_moe_common(); TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32) - << "hidden_states_scale must be float."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128) - << "hidden_states_scale dim0 must match hidden_states dim1 / 128."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens) - << "hidden_states_scale dim1 must match num_tokens."; + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32) + << "hidden_states_scale must be float."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128) + << "hidden_states_scale dim0 must match hidden_states dim1 / 128."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args->num_tokens) + << "hidden_states_scale dim1 must match num_tokens."; + } else if (quantization_type == Fp8QuantizationType::MxFp8) { + TVM_FFI_CHECK(weight_layout == batchedGemm::gemm::MatrixLayout::MajorK, + "weight_layout must be MajorK for MxFp8."); + TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_uint8); + } else { + 
TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "trtllm_fp8_block_scale_moe only supports DeepSeekFp8 or MxFp8."; + } TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8."; TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) - << "gemm1_weights_scale must be float."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts) - << "gemm1_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) - << "intermediate_size must be a multiple of 128."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * args->intermediate_size / 128) - << "gemm1_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128) - << "gemm1_weights_scale has incorrect shape."; + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) + << "gemm1_weights_scale must be float."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), args->local_num_experts) + << "gemm1_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) + << "intermediate_size must be a multiple of 128."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), + intermediate_size_factor * args->intermediate_size / 128) + << "gemm1_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args->hidden_size / 128) + << "gemm1_weights_scale has incorrect shape."; + } else if (quantization_type == Fp8QuantizationType::MxFp8) { + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_uint8) + << "gemm1_weights_scale must be uint8."; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << 
"trtllm_fp8_block_scale_moe only supports DeepSeekFp8 or MxFp8."; + } - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32) - << "gemm2_weights_scale must be float."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts) - << "gemm2_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128) - << "gemm2_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128) - << "gemm2_weights_scale has incorrect shape."; + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32) + << "gemm2_weights_scale must be float."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), args->local_num_experts) + << "gemm2_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args->hidden_size / 128) + << "gemm2_weights_scale has incorrect shape."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), args->intermediate_size / 128) + << "gemm2_weights_scale has incorrect shape."; + } else if (quantization_type == Fp8QuantizationType::MxFp8) { + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_uint8) + << "gemm2_weights_scale must be uint8."; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "trtllm_fp8_block_scale_moe only supports DeepSeekFp8 or MxFp8."; + } check_weights_shape("gemm1"); check_weights_shape("gemm2"); - TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) - << "intermediate_size must be a multiple of 128."; + + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + TVM_FFI_ICHECK_EQ(args->intermediate_size % 128, 0) + << "intermediate_size must be a multiple of 128."; + } } void prepare_moe(int64_t& moe_tactic) override { @@ -900,17 +976,28 @@ class 
Fp8BlockScaleLauncher : public FusedMoeLauncher { workspace.total_max_padded_tokens, args->hidden_size, btg::dtypeGetNumBits(args->mDtypeOut)); - gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * args->intermediate_size}, - dl_uint8, hidden_states.device()); - gemm1_output_scale = - alloc_tensor({2 * args->intermediate_size / 128, workspace.total_max_padded_tokens}, - dl_float32, hidden_states.device()); + gemm1_output = alloc_tensor( + {max_num_padded_tokens_gemm1, intermediate_size_factor * args->intermediate_size}, dl_uint8, + hidden_states.device()); + + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + gemm1_output_scale = alloc_tensor({intermediate_size_factor * args->intermediate_size / 128, + workspace.total_max_padded_tokens}, + dl_float32, hidden_states.device()); + } else if (quantization_type == Fp8QuantizationType::MxFp8) { + // MxFP8 fuses the activation so no need for intermediate_size_factor + int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens_gemm1, + args->intermediate_size / 32); + gemm1_output_scale = alloc_tensor({sf_size}, dl_uint8, hidden_states.device()); + } - activation_output = alloc_tensor({max_num_padded_tokens_gemm1, args->intermediate_size}, - dl_uint8, hidden_states.device()); - activation_output_scale = - alloc_tensor({args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, - hidden_states.device()); + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + activation_output = alloc_tensor({max_num_padded_tokens_gemm1, args->intermediate_size}, + dl_uint8, hidden_states.device()); + activation_output_scale = + alloc_tensor({args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, + hidden_states.device()); + } gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args->hidden_size}, dl_bfloat16, hidden_states.device()); @@ -918,8 +1005,10 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { workspace.hidden_states_scale_linear = 
nullptr; workspace.gemm1_output = gemm1_output.data_ptr(); workspace.gemm1_output_scale = static_cast(gemm1_output_scale.data_ptr()); - workspace.activation_output = activation_output.data_ptr(); - workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); + } workspace.gemm2_output = gemm2_output.data_ptr(); workspace.gemm2_output_scale = nullptr; @@ -942,6 +1031,7 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { Tensor activation_output_scale; TensorView expert_indices; TensorView expert_weights; + Fp8QuantizationType quantization_type; public: // Override to handle pre-computed routing @@ -990,17 +1080,18 @@ class Fp8BlockScaleLauncher : public FusedMoeLauncher { static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, int64_t intermediate_size, int64_t num_local_experts, int64_t num_tokens, bool use_shuffled_weight, - int64_t weight_layout, btg::Dtype dtype_weights) { + int64_t weight_layout, btg::Dtype dtype_weights, + Fp8QuantizationType quantization_type) { Array> valid_configs; - std::vector supported_tile_nums(mSupportedTileNums.begin(), mSupportedTileNums.end()); + auto supported_tile_nums = getSupportedTileNums(quantization_type); std::set selected_tile_nums = computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); for (int32_t tile_N : selected_tile_nums) { auto moe_runner = std::make_unique( - dtype_weights, // dtype_weights for DeepSeek FP8 - true, // useDeepSeekFp8 + dtype_weights, // dtype_weights for DeepSeek FP8 + quantization_type == Fp8QuantizationType::DeepSeekFp8, // useDeepSeekFp8 tile_N, use_shuffled_weight, static_cast(weight_layout)); auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, @@ -1649,7 +1740,7 @@ Tensor 
trtllm_fp8_block_scale_moe( Optional n_group, Optional topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, Optional routed_scaling_factor, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl, - Array config_index) { + Array config_index, Fp8QuantizationType quantization_type) { // Basic type validation auto dtype = hidden_states.dtype(); @@ -1673,24 +1764,46 @@ Tensor trtllm_fp8_block_scale_moe( } TVM_FFI_ICHECK(dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) << "FP8 block scale MoE: hidden_states must be fp16, bf16, or fp8."; - TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32) - << "FP8 block scale MoE: hidden_states_scale must be float32."; + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32) + << "FP8 block scale MoE: hidden_states_scale must be float32."; + } else if (quantization_type == Fp8QuantizationType::MxFp8) { + TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_uint8) + << "FP8 block scale MoE: hidden_states_scale must be uint8."; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "trtllm_fp8_block_scale_moe only supports DeepSeekFp8 or MxFp8."; + } TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "FP8 block scale MoE: gemm1_weights must be fp8."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) - << "FP8 block scale MoE: gemm1_weights_scale must be float32."; TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "FP8 block scale MoE: gemm2_weights must be fp8."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32) - << "FP8 block scale MoE: gemm2_weights_scale must be float32."; + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) + << "FP8 block scale MoE: gemm1_weights_scale must be float32."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), 
dl_float32) + << "FP8 block scale MoE: gemm2_weights_scale must be float32."; + } else if (quantization_type == Fp8QuantizationType::MxFp8) { + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_uint8) + << "FP8 block scale MoE: gemm1_weights_scale must be uint8."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_uint8) + << "FP8 block scale MoE: gemm2_weights_scale must be uint8."; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "trtllm_fp8_block_scale_moe only supports DeepSeekFp8 or MxFp8."; + } + + if (quantization_type == Fp8QuantizationType::MxFp8) { + TVM_FFI_ICHECK(use_shuffled_weight) << "use_shuffled_weight must be true for MxFp8."; + TVM_FFI_ICHECK(weight_layout == 0) << "weight_layout must be 0 for MxFp8."; + } auto const num_tokens = hidden_states.size(0); auto const hidden_size = hidden_states.size(1); - std::vector mSupportedTileN(Fp8BlockScaleLauncher::mSupportedTileNums.begin(), - Fp8BlockScaleLauncher::mSupportedTileNums.end()); + auto supported_tile_nums = Fp8BlockScaleLauncher::getSupportedTileNums(quantization_type); std::set selected_tile_nums = - computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, local_num_experts); // Create a map of launchers for each tile size std::unordered_map> launchers_map; @@ -1713,7 +1826,8 @@ Tensor trtllm_fp8_block_scale_moe( // Create and initialize launcher for this tile size auto launcher = std::make_unique( routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, - gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, expert_indices, expert_weights); + gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, expert_indices, expert_weights, + quantization_type); launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, weight_layout); @@ -1733,8 +1847,9 @@ Tensor trtllm_fp8_block_scale_moe( auto& selected_launcher = launchers_map.at(tile_N); // Run the 
launcher with DeepSeek FP8 enabled - it will create its own runner internally - auto result = selected_launcher->run(config, enable_pdl, false /* use_routing_scales_on_input */, - true /* use_deep_seek_fp8 */)[0]; + auto result = selected_launcher->run( + config, enable_pdl, false /* use_routing_scales_on_input */, + quantization_type == Fp8QuantizationType::DeepSeekFp8 /* use_deep_seek_fp8 */)[0]; // Return the result tensor return result; } @@ -1971,7 +2086,7 @@ Array trtllm_mxint4_block_scale_moe( } Array> trtllm_get_valid_moe_configs( - int64_t const dtype_act_, int64_t const dtype_weights_, bool const useDeepSeekFp8, + int64_t const dtype_act_, int64_t const dtype_weights_, Fp8QuantizationType quantization_type, int64_t const top_k, int64_t const hidden_size, int64_t const intermediate_size, int64_t const num_local_experts, int64_t const act_type, bool const use_shuffled_weight, int64_t const weight_layout, int64_t const num_tokens) { @@ -1991,17 +2106,21 @@ Array> trtllm_get_valid_moe_configs( } else if (dtype_act == btg::Dtype::E4m3 && dtype_weights == btg::Dtype::E4m3) { // FP8 - if (!useDeepSeekFp8) { + if (quantization_type == Fp8QuantizationType::DeepSeekFp8) { + // FP8 block scale + return Fp8BlockScaleLauncher::getValidConfigs( + top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, use_shuffled_weight, + weight_layout, dtype_weights, quantization_type); + } else { // FP8 per-tensor scale return Fp8PerTensorLauncher::getValidConfigs( top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, act_type, use_shuffled_weight, weight_layout, dtype_act, dtype_weights); - } else { - // FP8 block scale - return Fp8BlockScaleLauncher::getValidConfigs( - top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, use_shuffled_weight, - weight_layout, dtype_weights); } + } else if (dtype_act == btg::Dtype::MxE4m3 && dtype_weights == btg::Dtype::MxE4m3) { + return Fp8BlockScaleLauncher::getValidConfigs( + top_k, hidden_size, 
intermediate_size, num_local_experts, num_tokens, use_shuffled_weight, + weight_layout, dtype_weights, quantization_type); } else if (dtype_weights == btg::Dtype::E2m1 || dtype_weights == btg::Dtype::MxE2m1) { // FP4 block scale return FP4BlockScaleLauncher::getValidConfigs(top_k, hidden_size, intermediate_size, @@ -2010,9 +2129,10 @@ Array> trtllm_get_valid_moe_configs( } TVM_FFI_LOG_AND_THROW(NotImplementedError) - << "Unsupported data type combination for getValidConfigs: " << "dtype_act=" - << static_cast(dtype_act) << ", dtype_weights=" << static_cast(dtype_weights) - << ", useDeepSeekFp8=" << useDeepSeekFp8; + << "Unsupported data type combination for getValidConfigs: " + << "dtype_act=" << static_cast(dtype_act) + << ", dtype_weights=" << static_cast(dtype_weights) + << ", quantization_type=" << fp8QuantizationTypeToString(quantization_type); // Unreachable code - added to suppress compiler warning return Array>(); diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index e3615fa1c4..a1bf8139cc 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -517,6 +517,9 @@ void Runner::setOpsData(MoERunnerArgs const& args, MoEWorkspace const& workspace std::tuple Runner::getWorkspaceSizeInBytes(MoERunnerArgs const& args, int64_t configIndex) const { + FLASHINFER_CHECK(configIndex >= 0 && configIndex < static_cast(mPassingConfigs.size()), + "Invalid MoE config index ", configIndex, ", valid range is [0, ", + static_cast(mPassingConfigs.size()) - 1, "]."); auto const& config = mPassingConfigs[configIndex]; auto workspace_size_fc1 = static_cast(mPermuteGemm1.getWorkspaceSizeInBytes( @@ -567,6 +570,9 @@ int64_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, cudaStream_t stream, int64_t configIndex, bool enable_pdl) { + FLASHINFER_CHECK(configIndex >= 0 && configIndex < static_cast(mPassingConfigs.size()), 
+ "Invalid MoE config index ", configIndex, ", valid range is [0, ", + static_cast(mPassingConfigs.size()) - 1, "]."); // Setup all operation data moe::dev::activation::Data activationData; moe::dev::finalize::Data finalizeData; diff --git a/flashinfer/artifacts.py b/flashinfer/artifacts.py index 14f6164972..cce73a6827 100644 --- a/flashinfer/artifacts.py +++ b/flashinfer/artifacts.py @@ -89,7 +89,7 @@ class ArtifactPath: TRTLLM_GEN_FMHA: str = "e86f0e45764555d070c3d143b4caaea61a45b777/fmha/trtllm-gen/" TRTLLM_GEN_BMM: str = ( - "e1e11bbfe0743743620ef997a6d5e8e2dbdf01cf/batched_gemm-2a674db-79e4d37" + "456b1ae890d436c794b17e4435b41b849d3e5950/batched_gemm-2a674db-3a84a12" ) TRTLLM_GEN_GEMM: str = ( "1fddc48b7b48af33914d040051b3e2ee9ba4701e/gemm-145d1b1-9b113e3" @@ -110,7 +110,7 @@ class CheckSumHash: "c4c93904a4c72b8a3d0d5c525c6decb71c835b477d7d75651ecaaa7007c5a3ef" ) TRTLLM_GEN_BMM: str = ( - "03b1a419b594b7a4613ea8437c172dc2627d56bd360be25aa604859dc12a05fb" + "b9121fed5dd7700b7c2a0dcbcf2ef022483855cf585263324275b0072cca6bb7" ) DEEPGEMM: str = "1a2a166839042dbd2a57f48051c82cd1ad032815927c753db269a4ed10d0ffbf" TRTLLM_GEN_GEMM: str = ( diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index fac19433fe..00b8a1230f 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -16,6 +16,7 @@ from .core import ( ActivationType, + Fp8QuantizationType, RoutingMethodType, WeightLayout, convert_to_block_layout, @@ -52,6 +53,7 @@ __all__ = [ "ActivationType", + "Fp8QuantizationType", "RoutingMethodType", "WeightLayout", "convert_to_block_layout", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 2f403078bd..625f4830f1 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -173,6 +173,28 @@ class WeightLayout(IntEnum): BlockMajorK = 2 +# The type of gated activation function +# Please keep this in sync with the counterpart defined in 
include/flashinfer/trtllm/fused_moe/runner.h +class GatedActType(IntEnum): + # SwiGlu + SwiGlu = 0 + # GeGlu + GeGlu = 1 + + +# The type of FP8 quantization +# Please keep this in sync with the counterpart defined in trtllm_fused_moe_kernel_launcher.cu +class Fp8QuantizationType(IntEnum): + # No FP8 quantization + NoneFp8 = 0 + # DeepSeek FP8 + DeepSeekFp8 = 1 + # MxFp8 x MxFp8 + MxFp8 = 2 + # FP8 per-tensor scale (fourth enumerator of the C++ enum, ordinal 3) + PerTensorFp8 = 3 + + @functools.cache def is_trtllm_moe_supported( dtype_weights: DtypeTrtllmGen, @@ -986,7 +1006,7 @@ def __init__( num_local_experts: int, dtype_act: DtypeTrtllmGen, dtype_weights: DtypeTrtllmGen, - use_deepseek_fp8: bool, + fp8_quantization_type: Fp8QuantizationType, hidden_size: int, intermediate_size: int, activation_type: int = ActivationType.Swiglu, @@ -998,7 +1018,7 @@ def __init__( self.top_k = top_k self.dtype_act = dtype_act self.dtype_weights = dtype_weights - self.use_deepseek_fp8 = use_deepseek_fp8 + self.fp8_quantization_type = fp8_quantization_type self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size @@ -1025,7 +1045,7 @@ def get_valid_tactics( instance_key = ( self.dtype_act, self.dtype_weights, - self.use_deepseek_fp8, + self.fp8_quantization_type, self.top_k, self.hidden_size, self.intermediate_size, @@ -1112,18 +1132,33 @@ def forward( elif ( self.dtype_act == DtypeTrtllmGen.E4m3 and self.dtype_weights == DtypeTrtllmGen.E4m3 + ) or ( + self.dtype_act == DtypeTrtllmGen.MxE4m3 + and self.dtype_weights == DtypeTrtllmGen.MxE4m3 ): # FP8 operations - if self.use_deepseek_fp8: + if ( + self.fp8_quantization_type == Fp8QuantizationType.DeepSeekFp8 + or self.fp8_quantization_type == Fp8QuantizationType.MxFp8 + ): # FP8 block scale current_num_tokens = hidden_states.shape[0] current_hidden_size = hidden_states.shape[1] - current_hidden_states_scale = torch.full( - (current_hidden_size // 128, current_num_tokens), - 2.0, - dtype=torch.float, - device=hidden_states.device, - ) + if self.fp8_quantization_type == 
Fp8QuantizationType.DeepSeekFp8: + current_hidden_states_scale = torch.full( + (current_hidden_size // 128, current_num_tokens), + 2.0, + dtype=torch.float, + device=hidden_states.device, + ) + elif self.fp8_quantization_type == Fp8QuantizationType.MxFp8: + current_hidden_states_scale = extra_inputs[0] + + else: + raise ValueError( + f"Unsupported FP8 quantization type: {self.fp8_quantization_type}" + ) + moe_op.trtllm_fp8_block_scale_moe( routing_logits, topk_ids, @@ -1149,6 +1184,7 @@ def forward( kwargs["weight_layout"], kwargs["enable_pdl"], [-1, -1] if tactic == -1 else tactic, + self.fp8_quantization_type, ) else: # FP8 per tensor scale @@ -1319,7 +1355,7 @@ def trtllm_bf16_moe_op( num_local_experts=local_num_experts, dtype_act=dtype_act, dtype_weights=dtype_weights, - use_deepseek_fp8=False, + fp8_quantization_type=Fp8QuantizationType.NoneFp8, hidden_size=hidden_size, intermediate_size=intermediate_size, weight_layout=weight_layout, @@ -1452,7 +1488,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( num_local_experts=local_num_experts, dtype_act=dtype_act, dtype_weights=dtype_weights, - use_deepseek_fp8=False, # per_tensor mode + fp8_quantization_type=Fp8QuantizationType.NoneFp8, # per_tensor mode hidden_size=hidden_size, intermediate_size=intermediate_size, weight_layout=WeightLayout.MajorK, @@ -1569,6 +1605,7 @@ def trtllm_fp8_block_scale_moe_op( weight_layout: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, ) -> torch.Tensor: # Determine routing mode: compute from logits or use pre-computed if routing_logits is None: @@ -1611,15 +1648,23 @@ def trtllm_fp8_block_scale_moe_op( else torch.empty(0, dtype=routing_dtype, device=hidden_states.device) ) - dtype_act = DtypeTrtllmGen.E4m3 # FP8 activation - dtype_weights = DtypeTrtllmGen.E4m3 # FP8 weights + dtype_act = ( + DtypeTrtllmGen.E4m3 + if fp8_quantization_type == Fp8QuantizationType.DeepSeekFp8 + else 
DtypeTrtllmGen.MxE4m3 + ) # FP8 activation + dtype_weights = ( + DtypeTrtllmGen.E4m3 + if fp8_quantization_type == Fp8QuantizationType.DeepSeekFp8 + else DtypeTrtllmGen.MxE4m3 + ) # FP8 weights moe_runner = MoERunner( top_k=top_k, num_local_experts=local_num_experts, dtype_act=dtype_act, dtype_weights=dtype_weights, - use_deepseek_fp8=True, # block_scale mode + fp8_quantization_type=fp8_quantization_type, # block_scale mode hidden_size=hidden_size, intermediate_size=intermediate_size, weight_layout=weight_layout, @@ -1682,6 +1727,7 @@ def trtllm_fp8_block_scale_moe_op( weight_layout, enable_pdl, [-1, -1] if tactic == -1 else tactic, + fp8_quantization_type, ) return result @@ -1712,6 +1758,7 @@ def _fake_trtllm_fp8_block_scale_moe( weight_layout: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, ): seq_len = hidden_states.shape[0] hidden_size = hidden_states.shape[1] @@ -1809,7 +1856,7 @@ def trtllm_fp4_block_scale_moe_op( num_local_experts=num_local_experts, dtype_act=dtype_act, dtype_weights=dtype_weights, - use_deepseek_fp8=False, + fp8_quantization_type=Fp8QuantizationType.NoneFp8, hidden_size=hidden_size, intermediate_size=intermediate_size, activation_type=activation_type, @@ -2007,7 +2054,7 @@ def trtllm_mxint4_block_scale_moe_op( num_local_experts=num_local_experts, dtype_act=dtype_act, dtype_weights=dtype_weights, - use_deepseek_fp8=False, + fp8_quantization_type=Fp8QuantizationType.NoneFp8, hidden_size=hidden_size, intermediate_size=intermediate_size, activation_type=ActivationType.Swiglu, @@ -2298,6 +2345,7 @@ def trtllm_fp8_block_scale_moe( weight_layout: int = 0, enable_pdl: Optional[bool] = None, tune_max_num_tokens: int = 8192, + fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, ) -> torch.Tensor: """FP8 block scale MoE operation. 
@@ -2309,11 +2357,11 @@ def trtllm_fp8_block_scale_moe( gemm1_weights: tensor of first layer weights - [num_experts, 2*intermediate_size, hidden_size] if weight_layout == WeightLayout.MajorK - [num_experts, 2*intermediate_size // 128, hidden_size, 128] if weight_layout == WeightLayout.BlockMajorK - gemm1_weights_scale: [num_experts, 2*intermediate_size//128, hidden_size//128] tensor of first layer block scales + gemm1_weights_scale: [num_experts, 2*intermediate_size//(32 if mxfp8 else 128), hidden_size//(32 if mxfp8 else 128)] tensor of first layer block scales gemm2_weights: tensor of second layer weights - [num_experts, hidden_size, intermediate_size] if weight_layout == WeightLayout.MajorK - [num_experts, hidden_size//128, intermediate_size, 128] if weight_layout == WeightLayout.BlockMajorK - gemm2_weights_scale: [num_experts, hidden_size//128, intermediate_size//128] tensor of second layer block scales + gemm2_weights_scale: [num_experts, hidden_size//(32 if mxfp8 else 128), intermediate_size//(32 if mxfp8 else 128)] tensor of second layer block scales num_experts: Total number of experts top_k: Number of experts to route to per token n_group: Number of expert groups @@ -2328,6 +2376,7 @@ def trtllm_fp8_block_scale_moe( - 2: BlockMajorK - Blocked along K dimension [K/blockK, Mn, blockK] enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. tune_max_num_tokens(int): Maximum number of tokens for tuning. 
(default: 8192) + fp8_quantization_type: Type of FP8 quantization to use (default: DeepSeekFp8) Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ @@ -2359,6 +2408,7 @@ def trtllm_fp8_block_scale_moe( weight_layout, enable_pdl, tune_max_num_tokens, + fp8_quantization_type, ) @@ -2386,6 +2436,7 @@ def trtllm_fp8_block_scale_routed_moe( enable_pdl: Optional[bool] = None, output: Optional[torch.Tensor] = None, tune_max_num_tokens: int = 8192, + fp8_quantization_type: Fp8QuantizationType = Fp8QuantizationType.DeepSeekFp8, ) -> torch.Tensor: """FP8 block scale MoE operation with pre-computed routing (packed format). @@ -2400,11 +2451,11 @@ def trtllm_fp8_block_scale_routed_moe( Can be created as: (topk_ids.int32 << 16) | expert_weights.bfloat16.view(int16) routing_bias: [num_experts] tensor of routing bias (can be None) hidden_states: [seq_len, hidden_size] tensor of input hidden states - hidden_states_scale: [hidden_size//128, seq_len] tensor of hidden states block scales + hidden_states_scale: [hidden_size//(32 if mxfp8 else 128), seq_len] tensor of hidden states block scales gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] tensor of first layer weights - gemm1_weights_scale: [num_experts, 2*intermediate_size//128, hidden_size//128] tensor of first layer block scales + gemm1_weights_scale: [num_experts, 2*intermediate_size//(32 if mxfp8 else 128), hidden_size//(32 if mxfp8 else 128)] tensor of first layer block scales gemm2_weights: [num_experts, hidden_size, intermediate_size] tensor of second layer weights - gemm2_weights_scale: [num_experts, hidden_size//128, intermediate_size//128] tensor of second layer block scales + gemm2_weights_scale: [num_experts, hidden_size//(32 if mxfp8 else 128), intermediate_size//(32 if mxfp8 else 128)] tensor of second layer block scales num_experts: Total number of experts top_k: Number of experts to route to per token n_group: Number of expert groups @@ -2420,6 +2471,7 @@ def 
trtllm_fp8_block_scale_routed_moe( output (Optional[torch.Tensor]): shape [seq_len, hidden_size] Optional inplace output tensor. tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) + fp8_quantization_type: Type of FP8 quantization to use (default: DeepSeekFp8) Returns: torch.Tensor: Output tensor of shape [seq_len, hidden_size] """ @@ -2448,6 +2500,7 @@ def trtllm_fp8_block_scale_routed_moe( weight_layout, enable_pdl, tune_max_num_tokens, + fp8_quantization_type, ) diff --git a/tests/moe/test_dpsk_fused_moe_fp8.py b/tests/moe/test_dpsk_fused_moe_fp8.py index cd44f2faf2..35d9aae594 100644 --- a/tests/moe/test_dpsk_fused_moe_fp8.py +++ b/tests/moe/test_dpsk_fused_moe_fp8.py @@ -597,7 +597,7 @@ def test_correctness_dpsk_fp8_fused_moe( class FP8BlockScaleMoe: def __init__(self): self.name = "FP8BlockScale" - self.quant_mode = QuantMode.FP8_BLOCK_SCALE + self.quant_mode = QuantMode.FP8_BLOCK_SCALE_DEEPSEEK moe_impl = FP8BlockScaleMoe() diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index a93767e457..a49d6c725b 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -30,6 +30,7 @@ mxfp8_quantize, reorder_rows_for_gated_act_gemm, shuffle_matrix_a, + shuffle_matrix_sf_a, ) from flashinfer.autotuner import autotune from flashinfer.fp4_quantization import block_scale_interleave @@ -45,6 +46,7 @@ from flashinfer.fused_moe.core import ( get_w2_permute_indices_with_cache, _maybe_get_cached_w3_w1_permute_indices, + Fp8QuantizationType, ) from .utils import is_gated_activation, skip_checks, QuantMode @@ -828,11 +830,17 @@ def get_tolerances(self): class FP8BlockScaleMoe(Moe): - """FP8 MoE implementation with block scaling (DeepSeek style).""" + """FP8 MoE implementation with block scaling (DeepSeek style or MxFp8 x MxFp8).""" + + def __init__( + self, fp8_quantization_type: QuantMode = QuantMode.FP8_BLOCK_SCALE_DEEPSEEK + ): + super().__init__() + 
self.fp8_quantization_type = fp8_quantization_type @property def quant_mode(self) -> QuantMode: - return QuantMode.FP8_BLOCK_SCALE + return self.fp8_quantization_type def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): """Quantize weights to FP8 with block scaling.""" @@ -842,17 +850,30 @@ def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): 2 ] # [num_experts, 2*intermediate_size, hidden_size] - # Quantize weights to FP8 - gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn) - gemm1_scales = 2 * torch.rand( - (num_experts, 2 * intermediate_size // 128, hidden_size // 128), - device="cuda", - ).to(torch.float) - - gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn) - gemm2_scales = 2 * torch.rand( - (num_experts, hidden_size // 128, intermediate_size // 128), device="cuda" - ).to(torch.float) + if self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_DEEPSEEK: + # Quantize weights to FP8 + gemm1_weights_fp8 = gemm1_weights.to(torch.float8_e4m3fn) + gemm1_scales = 2 * torch.rand( + (num_experts, 2 * intermediate_size // 128, hidden_size // 128), + device="cuda", + ).to(torch.float) + + gemm2_weights_fp8 = gemm2_weights.to(torch.float8_e4m3fn) + gemm2_scales = 2 * torch.rand( + (num_experts, hidden_size // 128, intermediate_size // 128), + device="cuda", + ).to(torch.float) + elif self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8: + gemm1_weights_fp8, gemm1_scales = mxfp8_quantize_batches( + gemm1_weights, False + ) + gemm2_weights_fp8, gemm2_scales = mxfp8_quantize_batches( + gemm2_weights, False + ) + else: + raise ValueError( + f"Unsupported FP8 quantization type: {self.fp8_quantization_type}" + ) return { "hidden_states_scale_global": None, # Block scales computed at runtime @@ -864,7 +885,12 @@ def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): "gemm2_scales_global": None, } - def quantize_inputs(self, hidden_states: torch.Tensor, 
hidden_states_scale_global): + def quantize_inputs( + self, + hidden_states: torch.Tensor, + hidden_states_scale_global: torch.Tensor = None, + is_swizzling: bool = False, + ): """For FP8 block scaling, no pre-quantization - everything happens at runtime.""" def to_float8_blockwise( @@ -923,7 +949,21 @@ def to_float8_blockwise( return quantized_x, scales # todo(Yingyi):quantize bf16 to fp8 - hidden_states_quant, hidden_states_scale = to_float8_blockwise(hidden_states) + if self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_DEEPSEEK: + hidden_states_quant, hidden_states_scale = to_float8_blockwise( + hidden_states + ) + elif self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8: + hidden_states_quant, hidden_states_scale = mxfp8_quantize( + hidden_states, is_swizzling + ) + hidden_states_scale = hidden_states_scale.view(torch.uint8).reshape( + *hidden_states.shape[:-1], -1 + ) + else: + raise ValueError( + f"Unsupported FP8 quantization type: {self.fp8_quantization_type}" + ) return { "hidden_states": hidden_states_quant, "hidden_states_scale": hidden_states_scale, @@ -948,17 +988,70 @@ def prepare_static_weights_for_kernel( if use_shuffled_weight: # FIXME: this depends on the kernel internals - epilogue_tile_m = 64 + epilogue_tile_m = ( + 64 + if self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_DEEPSEEK + else 128 + ) + + intermediate_size_factor = ( + 2 if is_gated_activation(args.activation_type) else 1 + ) + + gemm1_weights_fp8_interleaved = args.gemm1_weights.clone() + gemm1_scales_fp8_interleaved = args.gemm1_scales.clone() + if self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8: + # Reorder rows of W1 for fused gated activation + gemm1_weights_fp8_interleaved = [] + gemm1_scales_fp8_interleaved = [] + for i in range(num_experts): + gemm1_weights_fp8_interleaved.append( + reorder_rows_for_gated_act_gemm( + args.gemm1_weights[i] + .clone() + .reshape(intermediate_size_factor * intermediate_size, -1) + ) + ) + 
gemm1_scales_fp8_interleaved.append( + reorder_rows_for_gated_act_gemm( + args.gemm1_scales[i] + .clone() + .reshape(intermediate_size_factor * intermediate_size, -1) + ) + ) + + # Stack weights and scales for all experts + gemm1_weights_fp8_interleaved = torch.stack( + gemm1_weights_fp8_interleaved + ).reshape(args.gemm1_weights.shape) + gemm1_scales_fp8_interleaved = torch.stack( + gemm1_scales_fp8_interleaved + ).reshape(args.gemm1_scales.shape) gemm1_weights_fp8_shuffled = [] gemm2_weights_fp8_shuffled = [] + gemm1_scales_fp8_shuffled = [] + gemm2_scales_fp8_shuffled = [] for i in range(num_experts): tmp_weights1 = shuffle_matrix_a( - args.gemm1_weights[i].view(torch.uint8), epilogue_tile_m + gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m ) tmp_weights2 = shuffle_matrix_a( args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m ) + if self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8: + tmp_scales1 = shuffle_matrix_sf_a( + gemm1_scales_fp8_interleaved[i] + .view(torch.uint8) + .reshape(2 * intermediate_size, -1), + epilogue_tile_m, + ) + tmp_scales2 = shuffle_matrix_sf_a( + args.gemm2_scales[i].view(torch.uint8).reshape(hidden_size, -1), + epilogue_tile_m, + ) + gemm1_scales_fp8_shuffled.append(tmp_scales1) + gemm2_scales_fp8_shuffled.append(tmp_scales2) if weight_layout == WeightLayout.BlockMajorK: block_k = 128 @@ -974,15 +1067,27 @@ def prepare_static_weights_for_kernel( kernel_gemm2_weights = torch.stack(gemm2_weights_fp8_shuffled).view( torch.float8_e4m3fn ) + if self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8: + kernel_gemm1_scales = torch.stack(gemm1_scales_fp8_shuffled).reshape( + args.gemm1_scales.shape + ) + kernel_gemm2_scales = torch.stack(gemm2_scales_fp8_shuffled).reshape( + args.gemm2_scales.shape + ) + else: + kernel_gemm1_scales = args.gemm1_scales + kernel_gemm2_scales = args.gemm2_scales else: kernel_gemm1_weights = args.gemm1_weights kernel_gemm2_weights = args.gemm2_weights + 
kernel_gemm1_scales = args.gemm1_scales + kernel_gemm2_scales = args.gemm2_scales return { "gemm1_weights": kernel_gemm1_weights, - "gemm1_scales": args.gemm1_scales, + "gemm1_scales": kernel_gemm1_scales, "gemm2_weights": kernel_gemm2_weights, - "gemm2_scales": args.gemm2_scales, + "gemm2_scales": kernel_gemm2_scales, "use_shuffled_weight": use_shuffled_weight, "weight_layout": weight_layout, } @@ -1011,6 +1116,15 @@ def call_moe( "NaN detected in hidden_states_fp8" ) + if self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8: + quantization_mode = Fp8QuantizationType.MxFp8 + elif self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_DEEPSEEK: + quantization_mode = Fp8QuantizationType.DeepSeekFp8 + else: + raise ValueError( + f"Unsupported FP8 quantization type: {self.fp8_quantization_type}" + ) + # Use autotuner for optimal kernel selection with autotune(enable_autotune): output = trtllm_fp8_block_scale_moe( @@ -1035,12 +1149,20 @@ def call_moe( weight_layout=static_data["weight_layout"], enable_pdl=enable_pdl, tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + fp8_quantization_type=quantization_mode, ) return output.to(torch.float) def compute_reference(self, args): """FP8 block-scale reference implementation.""" - return run_moe_reference_dsfp8(args) + if self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_DEEPSEEK: + return run_moe_reference_dsfp8(args) + elif self.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8: + return run_moe_reference_mxfp8(args) + else: + raise ValueError( + f"Unsupported FP8 quantization type: {self.fp8_quantization_type}" + ) def get_tolerances(self): """Get FP8 block-scale accuracy tolerances.""" @@ -1403,8 +1525,12 @@ def get_tolerances(self): # ==================================================================================== def get_moe_impl(quant_mode: QuantMode): """Factory function to get the appropriate MoE implementation.""" - if quant_mode == QuantMode.FP8_BLOCK_SCALE: - return FP8BlockScaleMoe() + if 
quant_mode == QuantMode.FP8_BLOCK_SCALE_DEEPSEEK: + return FP8BlockScaleMoe( + fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK + ) + elif quant_mode == QuantMode.FP8_BLOCK_SCALE_MXFP8: + return FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8) elif quant_mode == QuantMode.FP8_PER_TENSOR: return FP8PerTensorMoe() else: @@ -1890,6 +2016,39 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): return output +def mxfp8_quantize_batches(a, is_swizzling=True): + """MxFp8 batch quantization function with centralized global scale factor calculation.""" + num_batches = a.size(0) + a_quant = [] + a_scales = [] + for i in range(num_batches): + mx_fp8_quant, mx_fp8_scale = mxfp8_quantize(a[i], is_swizzling) + a_quant.append(mx_fp8_quant) + a_scales.append(mx_fp8_scale.view(torch.uint8)) + + result_a_quant = torch.stack(a_quant) + result_a_scales = torch.stack(a_scales) + + return result_a_quant, result_a_scales + + +def mxfp8_dequantize_batches(a, a_scales, is_swizzling=True): + """MxFp8 batch dequantization function.""" + num_batches = a.size(0) + a_dequant = [] + for i in range(num_batches): + mx_fp8_dequant = mxfp8_dequantize_host( + a[i].cpu().view(torch.uint8), + a_scales[i].cpu().view(torch.uint8).reshape(-1), + is_swizzling, + ) + a_dequant.append(mx_fp8_dequant.cuda()) + + result_a_dequant = torch.stack(a_dequant) + + return result_a_dequant + + # ==================================================================================== # Common MoE Reference Implementation # ==================================================================================== @@ -1989,7 +2148,10 @@ def run_moe_dequant(args, quant_mode: QuantMode): ) activation_output = activation_output.to(torch.float) args.c_global_sf = c_global_sf - elif quant_mode == QuantMode.FP4_MXFP4_MXFP8: + elif ( + quant_mode == QuantMode.FP4_MXFP4_MXFP8 + or quant_mode == QuantMode.FP8_BLOCK_SCALE_MXFP8 + ): activation_output, scale_bytes = mxfp8_quantize( 
activation_output.to(torch.bfloat16), True ) @@ -2105,8 +2267,46 @@ def run_moe_reference_fp4(args, quant_mode: QuantMode): return run_moe_dequant(args_dequant, quant_mode), args_dequant +def run_moe_reference_mxfp8(args): + hidden_states_dequant = mxfp8_dequantize_host( + args.hidden_states.cpu().view(torch.uint8), + args.hidden_states_scale.cpu().view(torch.uint8).reshape(-1), + False, # is_sf_swizzled_layout + ).cuda() + + gemm1_weights_dequant = mxfp8_dequantize_batches( + args.gemm1_weights, + args.gemm1_scales, + False, + ).cuda() + + gemm2_weights_dequant = mxfp8_dequantize_batches( + args.gemm2_weights, + args.gemm2_scales, + False, + ).cuda() + + args_dequant = moe_args_dequant( + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + args.activation_type, + ) + + return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE_MXFP8), args_dequant + + def run_moe_reference_dsfp8(args): - """FP8 block-scale reference implementation.""" + """FP8 block-scale reference implementation (DeepSeek style).""" # Generate block scales at runtime for FP8 block scaling def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): @@ -2167,7 +2367,9 @@ def dequant_reference_dsfp8(input, scale, transpose_scale, block_m, block_n): args.activation_type, ) - return run_moe_dequant(args_dequant, QuantMode.FP8_BLOCK_SCALE), args_dequant + return run_moe_dequant( + args_dequant, QuantMode.FP8_BLOCK_SCALE_DEEPSEEK + ), args_dequant def run_moe_reference_per_tensor_scale_fp8(args): @@ -2552,7 +2754,14 @@ def run_moe_test( "moe_impl", [ pytest.param(BF16Moe(), id="BF16xBF16"), - pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK), + id="FP8_Block_DeepSeek", + ), + 
pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8), + id="FP8_Block_MxFp8", + ), pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), @@ -2704,7 +2913,14 @@ def test_renormalize_routing( "moe_impl", [ pytest.param(FP8PerTensorMoe(), id="FP8_PerTensor"), - pytest.param(FP8BlockScaleMoe(), id="FP8_Block"), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK), + id="FP8_Block_DeepSeek", + ), + pytest.param( + FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8), + id="FP8_Block_MxFp8", + ), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), diff --git a/tests/moe/utils.py b/tests/moe/utils.py index fae45f0415..7c8339cecf 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -19,6 +19,7 @@ from enum import IntEnum from flashinfer import ActivationType, RoutingMethodType from flashinfer.utils import get_compute_capability +from flashinfer.fused_moe import WeightLayout class QuantMode(IntEnum): @@ -27,10 +28,11 @@ class QuantMode(IntEnum): FP4_NVFP4_NVFP4 = 1 FP4_MXFP4_MXFP8 = 2 FP4_MXFP4_Bf16 = 3 - FP8_BLOCK_SCALE = 4 - FP8_PER_TENSOR = 5 - BF16 = 6 - MXINT4_BF16_BF16 = 7 + FP8_BLOCK_SCALE_DEEPSEEK = 4 + FP8_BLOCK_SCALE_MXFP8 = 5 + FP8_PER_TENSOR = 6 + BF16 = 7 + MXINT4_BF16_BF16 = 8 NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES = [ @@ -120,6 +122,19 @@ def skip_checks( pytest.skip( f"Incompatible: {moe_impl.name} + {weight_processing['use_shuffled_weight']} + {weight_processing['layout']}" ) + if ( + is_fp8_block_scale_moe + and moe_impl.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8 + and not weight_processing["use_shuffled_weight"] 
+ ): + pytest.skip("use_shuffled_weight must be true for MxFp8.") + if ( + is_fp8_block_scale_moe + and moe_impl.fp8_quantization_type == QuantMode.FP8_BLOCK_SCALE_MXFP8 + and weight_processing["layout"] != WeightLayout.MajorK + ): + pytest.skip("weight_layout must be MajorK for MxFp8.") + if intermediate_size not in routing_config["compatible_intermediate_size"]: pytest.skip( f"Incompatible: intermediate_size={intermediate_size} with {routing_config['routing_method_type'].name} routing ({routing_config['num_experts']} experts)"