From ff8d85e2924f757806346354b91e050e600f4472 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Wed, 25 Jun 2025 21:23:47 -0700 Subject: [PATCH] feat: Add support for MXFP8 x MXFP4 inputs to MOE plugin Signed-off-by: Daniel Stokes <40156487+djns99@users.noreply.github.com> --- .../include/moe_gemm_kernels.h | 24 +- .../cutlass_kernels/include/moe_kernels.h | 26 + .../include/moe_util_kernels.h | 9 +- .../cutlass_kernels/moe_gemm/moe_kernels.cu | 699 ++++++++++++------ cpp/tensorrt_llm/kernels/quantization.cuh | 98 ++- cpp/tensorrt_llm/thop/CMakeLists.txt | 1 + cpp/tensorrt_llm/thop/moeOp.cpp | 141 +++- cpp/tensorrt_llm/thop/moeUtilOp.cpp | 5 +- .../kernels/mixtureOfExpertsTest.cu | 280 ++++--- .../_torch/custom_ops/torch_custom_ops.py | 11 +- 10 files changed, 937 insertions(+), 357 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 6eb51c92f04..7ddd756e0d0 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -93,8 +93,28 @@ struct TmaWarpSpecializedGroupedGemmInput using NVFP4BlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; using MXFPXBlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; - constexpr static int MinNumRowsAlignmentNVFP4 = cute::size<0>(NVFP4BlockScaledConfig::SfAtom{}); - constexpr static int MinNumRowsAlignmentMXFPX = cute::size<0>(MXFPXBlockScaledConfig::SfAtom{}); + // 128 + // This is the alignment of the weight matrix the fully padded SF will refer to. + // We require the SFs to be aligned to this value (zero padded as needed) + // The weights do not need to be aligned to this value, CUTLASS will handle extra padding + // N here is a short hand for the outer dimension of the GEMM, this applies to both M & N dimension of the GEMM + constexpr static int MinNDimAlignmentNVFP4 = cute::size<0>(NVFP4BlockScaledConfig::SfAtom{}); + constexpr static int MinNDimAlignmentMXFPX = cute::size<0>(MXFPXBlockScaledConfig::SfAtom{}); + + // Block scale vector size * 4 + // This is the alignment of the weight matrix the fully padded SF will refer to. 
+ // We should never actually need to pad a buffer to this alignment + // The weights only need to be aligned to BlockScaleVectorSize, CUTLASS will handle extra padding + // The SFs only need to be aligned to 4 (zero padded as needed) + // K here is a short hand for the inner dimension of the GEMM + constexpr static int MinKDimAlignmentNVFP4 = cute::size<1>(NVFP4BlockScaledConfig::SfAtom{}); + constexpr static int MinKDimAlignmentMXFPX = cute::size<1>(MXFPXBlockScaledConfig::SfAtom{}); + + // Helper function to align a dimension to the SF alignment + constexpr static int64_t alignToSfDim(int64_t dim, int64_t alignment) + { + return (dim + alignment - 1) / alignment * alignment; + } using StrideA = std::remove_pointer_t>; // Use B because they will be swapped diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index 6adf5cbf348..912c3553bb0 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -210,6 +210,21 @@ struct QuantParams GemmInputs fc2; } fp8_mxfp4; + // MXFP8 MXFP4 quantization params + // This mode uses block scaled MXFP8 and MXFP4 weights + struct MXFP8MXFP4Inputs + { + struct GemmInputs + { + TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale + = nullptr; // (experts, n, k / 32) + float const* global_scale = nullptr; // (num_experts_per_node, ) + }; + + GemmInputs fc1; + GemmInputs fc2; + } mxfp8_mxfp4; + // FP4 quantization params struct FP4Inputs { @@ -291,6 +306,16 @@ struct QuantParams return qp; } + static QuantParams MXFP8MXFP4(TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc1_weight_block_scale, + float const* fc1_global_scale, // + TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale, float const* fc2_global_scale) + { + QuantParams qp; + qp.mxfp8_mxfp4.fc1 = {fc1_weight_block_scale, fc1_global_scale}; + qp.mxfp8_mxfp4.fc2 = {fc2_weight_block_scale, fc2_global_scale}; + return qp; + } + static QuantParams FP4(float const* fc1_act_global_scale, TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc1_weight_block_scale, float const* fc1_global_scale, // @@ -298,6 +323,7 @@ struct QuantParams TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale, float const* fc2_global_scale, // bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false) + { QuantParams qp; qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale}; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h index 6b346d730e7..b1676993ded 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h @@ -52,14 +52,13 @@ void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, i int64_t const num_tokens, int64_t const num_experts_per_node, int64_t const num_experts_per_token, int const start_expert_id, cudaStream_t stream); -template +template void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, - int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int const k, - int const num_experts_per_node, 
float const* fc1_act_global_scale, bool use_per_expert_act_scale, + int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k, + int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream, - void const* prequant_scales = nullptr); + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream); template void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 1610546e295..0caf687b569 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -55,6 +55,8 @@ #include "tensorrt_llm/kernels/preQuantScaleKernel.h" #include "tensorrt_llm/kernels/quantization.cuh" +#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h" + #ifndef CUDART_VERSION #error CUDART_VERSION Undefined! #elif (CUDART_VERSION >= 11050) @@ -932,19 +934,22 @@ __host__ __device__ constexpr T* safe_inc_ptr(T* ptr, size_t offset) __host__ __device__ constexpr int64_t getOffsetWeightSF(int64_t expert_id, int64_t gemm_n, int64_t gemm_k, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type) { - auto function = [=](int64_t min_alignment, int64_t block_size) + auto function = [=](int64_t min_n_dim_alignment, int64_t min_k_dim_alignment, int64_t block_size) { - int64_t rounded_gemm_n = cute::ceil_div(gemm_n, min_alignment) * min_alignment; + int64_t padded_gemm_n = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(gemm_n, min_n_dim_alignment); + int64_t padded_gemm_k = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(gemm_k, min_k_dim_alignment); assert(gemm_k % block_size == 0); - return expert_id * rounded_gemm_n * gemm_k / block_size; + return expert_id * padded_gemm_n * padded_gemm_k / block_size; }; switch (scaling_type) { case TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX: - return function(TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentMXFPX, + return function(TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX, + TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX, TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize); case TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4: - return function(TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentNVFP4, + return function(TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4, + TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4, TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize); case TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE: return 0; // No scaling factors, no offset } @@ -956,20 +961,27 @@ __host__ __device__ constexpr int64_t getOffsetWeightSF(int64_t expert_id, int64 __host__ __device__ constexpr int64_t getOffsetActivationSF(int64_t expert_id, int64_t token_offset, int64_t gemm_k, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type) { - auto function = [=](int64_t min_alignment, int64_t block_size) - { - // This formulation ensures that sf_offset[i + 1] - sf_offset[i] >= token_offset[i + 1] - token_offset[i]. 
- int64_t sf_offset = (token_offset + expert_id * (min_alignment - 1)) / min_alignment * min_alignment; + auto function = [=](int64_t min_n_dim_alignment, int64_t min_k_dim_alignment, int64_t block_size) + { + // This formulation ensures that: + // `sf_offset[i + 1] - sf_offset[i] >= padded(token_offset[i + 1] - token_offset[i])` + // is true for all possible token distributions. + int64_t padded_sf_start_offset = TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + token_offset + expert_id * (min_n_dim_alignment - 1), min_n_dim_alignment); + int64_t padded_gemm_k = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(gemm_k, min_k_dim_alignment); assert(gemm_k % block_size == 0); - return sf_offset * gemm_k / block_size; + assert(padded_gemm_k % block_size == 0); + return padded_sf_start_offset * padded_gemm_k / block_size; }; switch (scaling_type) { case TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX: - return function(TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentMXFPX, + return function(TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX, + TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX, TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize); case TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4: - return function(TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentNVFP4, + return function(TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4, + TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4, TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize); case TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE: return 0; // No scaling factors, no offset } @@ -978,15 +990,14 @@ __host__ __device__ constexpr int64_t getOffsetActivationSF(int64_t expert_id, i return 0; } -constexpr static int NVFP4_VEC_SIZE = 16; - -template -__device__ uint32_t quantizePackedFP4Value(ComputeElem& post_act_val, float global_scale_val, +template +__device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_scale_val, int64_t num_tokens_before_expert, int64_t expert_id, int64_t token_id, int64_t elem_idx, int64_t num_cols, - int64_t max_tokens_per_expert, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type) { - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = NVFP4_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + constexpr bool is_fp8 = std::is_same_v; + static constexpr int NumThreadsPerSF = VecSize / CVT_FP4_ELTS_PER_THREAD; // Quantize the input to FP4 static_assert(std::is_same_v || std::is_same_v); static_assert(ComputeElem::kElements == CVT_FP4_ELTS_PER_THREAD); @@ -1002,41 +1013,67 @@ __device__ uint32_t quantizePackedFP4Value(ComputeElem& post_act_val, float glob = act_sf_flat + getOffsetActivationSF(expert_id, num_tokens_before_expert, num_cols, scaling_type); // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, - elem_idx, std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out + = cvt_quant_to_fp4_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, + num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); // Do the conversion and set the output and scaling factor - auto func = 
(scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4) - ? &cvt_warp_fp16_to_fp4 - : &cvt_warp_fp16_to_fp4; - auto res = func(packed_vec, global_scale_val, sf_out); - return res; + auto func = [&]() + { + if constexpr (is_fp8) + { + return [](PackedVec& vec, float /* ignored */, uint8_t* SFout) -> uint64_t + { + static_assert(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == VecSize); + return cvt_warp_fp16_to_mxfp8(vec, SFout); + }; + } + else + { + return (scaling_type == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4) + ? &cvt_warp_fp16_to_fp4 + : &cvt_warp_fp16_to_fp4; + } + }(); + + return func(packed_vec, global_scale_val, sf_out); } +template __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id, - int64_t elem_idx, int64_t num_cols, int64_t max_tokens_per_expert, - TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, + int64_t elem_idx, int64_t num_cols, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) { - static constexpr int CVT_FP4_NUM_THREADS_PER_SF = NVFP4_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; // We need to offset into the scaling factors for just this expert auto act_sf_expert = act_sf_flat + getOffsetActivationSF(expert_id, num_tokens_before_expert, num_cols, - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); + (VecSize == TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize) + ? TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, - elem_idx, std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out + = cvt_quant_to_fp4_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, + num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); if (sf_out) { - auto const sf_in = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, source_token_id, elem_idx, - std::nullopt /* numRows */, num_cols, const_cast(input_sf), - FP4QuantizationSFLayout::SWIZZLED); - *sf_out = *sf_in; + if (input_sf) + { + auto const sf_in + = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols, const_cast(input_sf), + FP4QuantizationSFLayout::SWIZZLED); + *sf_out = *sf_in; + } + else + { + *sf_out = 0x00; + } } } @@ -1096,7 +1133,6 @@ __device__ void setupFP4BlockScalingFactors(TmaWarpSpecializedGroupedGemmInput& ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; layout_info.fpX_block_scaling_factors_A[expert] - = fp4_act_flat + getOffsetActivationSF(expert, num_tokens_before_expert, gemm_k, scaling_type); layout_info.fpX_block_scaling_factors_B[expert] @@ -1208,29 +1244,23 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir layout_info2.alpha_scale_ptr_array[expert] = alpha_scale_flat2 + expert; } - if (quant_params.fp4.fc1.weight_block_scale) + auto setupIfSelected = [&](auto bs_config, auto quant_type) { - setupFP4BlockScalingFactors(layout_info1, expert, - gemm_m, gemm1_n, gemm1_k, fp4_act_flat1, quant_params.fp4.fc1.weight_block_scale, num_tokens_before_expert); - } - if (quant_params.fp4.fc2.weight_block_scale) - { - setupFP4BlockScalingFactors(layout_info2, expert, - gemm_m, gemm2_n, gemm2_k, fp4_act_flat2, quant_params.fp4.fc2.weight_block_scale, num_tokens_before_expert); - } + if (quant_type.fc1.weight_block_scale) + { + setupFP4BlockScalingFactors(layout_info1, expert, gemm_m, gemm1_n, gemm1_k, + fp4_act_flat1, quant_type.fc1.weight_block_scale, num_tokens_before_expert); + } + if (quant_type.fc2.weight_block_scale) + { + setupFP4BlockScalingFactors(layout_info2, expert, gemm_m, gemm2_n, gemm2_k, + fp4_act_flat2, quant_type.fc2.weight_block_scale, num_tokens_before_expert); + } + }; - if (quant_params.fp8_mxfp4.fc1.weight_block_scale) - { - setupFP4BlockScalingFactors(layout_info1, expert, - gemm_m, gemm1_n, gemm1_k, fp4_act_flat1, quant_params.fp8_mxfp4.fc1.weight_block_scale, - num_tokens_before_expert); - } - if (quant_params.fp8_mxfp4.fc2.weight_block_scale) - { - setupFP4BlockScalingFactors(layout_info2, expert, - gemm_m, gemm2_n, gemm2_k, fp4_act_flat2, quant_params.fp8_mxfp4.fc2.weight_block_scale, - num_tokens_before_expert); - } + setupIfSelected(TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaledConfig{}, quant_params.fp4); + setupIfSelected(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaledConfig{}, quant_params.fp8_mxfp4); + setupIfSelected(TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaledConfig{}, quant_params.mxfp8_mxfp4); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); @@ -1398,15 +1428,8 @@ __global__ void computeStridesTmaWarpSpecializedLowLatencyKernel(TmaWarpSpeciali template __host__ __device__ constexpr static U arrayConvert(T const& input) { - using Type = typename U::Element; - static_assert(T::kElements == U::kElements); - U u; -#pragma unroll - for (int i = 0; i < U::kElements; i++) - { - u[i] = static_cast(input[i]); - } - return u; + cutlass::NumericArrayConverter converter; + return converter(input); } // Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing. 
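The scale-factor offset arithmetic changed above (alignToSfDim, getOffsetWeightSF, getOffsetActivationSF) is easiest to sanity-check in isolation. Below is a minimal standalone C++ sketch, an editor's illustration rather than part of the patch, reproducing the MXFPX-case arithmetic under the values documented in the header comments: an N-dimension alignment of 128, a K-dimension alignment of four scale-factor blocks (128 elements), and a block-scale vector size of 32. The helper names (weightSfOffset, activationSfOffset) and the k* constants are illustrative stand-ins for the functions and constants in this hunk.

#include <cstdint>
#include <cstdio>

// Assumed constants, mirroring the comments in the hunks above (MXFPX case).
constexpr int64_t kMinNDimAlign = 128; // MinNDimAlignmentMXFPX
constexpr int64_t kMinKDimAlign = 128; // MinKDimAlignmentMXFPX == 4 * block-scale vector size
constexpr int64_t kBlockSize = 32;     // MXFPXBlockScaleVectorSize

constexpr int64_t alignToSfDim(int64_t dim, int64_t alignment)
{
    return (dim + alignment - 1) / alignment * alignment;
}

// Each expert's weights own a fully padded (n, k) tile of scale factors.
constexpr int64_t weightSfOffset(int64_t expert_id, int64_t gemm_n, int64_t gemm_k)
{
    return expert_id * alignToSfDim(gemm_n, kMinNDimAlign) * alignToSfDim(gemm_k, kMinKDimAlign) / kBlockSize;
}

// Activations: the token offset is padded per expert so every expert's SF region
// is at least as large as its padded token count.
constexpr int64_t activationSfOffset(int64_t expert_id, int64_t token_offset, int64_t gemm_k)
{
    int64_t padded_start = alignToSfDim(token_offset + expert_id * (kMinNDimAlign - 1), kMinNDimAlign);
    return padded_start * alignToSfDim(gemm_k, kMinKDimAlign) / kBlockSize;
}

int main()
{
    // gemm_n = 200 pads to 256 SF rows, gemm_k = 512 is already aligned,
    // so each expert owns 256 * 512 / 32 = 4096 weight scale factors.
    static_assert(weightSfOffset(3, 200, 512) == 3 * 4096, "weight SF offset");
    // Expert 2 starting at token 130 pads up to SF row 384 = alignToSfDim(130 + 2 * 127, 128).
    static_assert(activationSfOffset(2, 130, 512) == 384 * 512 / 32, "activation SF offset");
    std::printf("offsets check out\n");
    return 0;
}

The property the activation formula preserves, as the in-code comment states, is that sf_offset[i + 1] - sf_offset[i] is never smaller than the padded token count of expert i, so the per-expert SF padding loops added later in this patch never write past another expert's region.
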
@@ -1422,65 +1445,94 @@ __host__ __device__ constexpr static U arrayConvert(T const& input) constexpr static int EXPAND_THREADS_PER_BLOCK = 256; -template +template __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, - int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int64_t const k, + int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size, int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) { + static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ, + "AWQ and Block Scaling are mutually exclusive"); #ifdef ENABLE_FP4 - constexpr bool is_fp4 = std::is_same_v; - constexpr bool is_fp4_input = is_fp4 && std::is_same_v; - constexpr bool need_fp4_quant = is_fp4 && !std::is_same_v; + constexpr bool is_mxfp8 = std::is_same_v + && BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX && !PRE_QUANT_AWQ; + constexpr bool is_mxfp8_input = is_mxfp8 && std::is_same_v; + constexpr bool need_mxfp8_quant = is_mxfp8 && !is_mxfp8_input; + constexpr bool is_nvfp4 = std::is_same_v + && BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 && !PRE_QUANT_AWQ; + constexpr bool is_nvfp4_input = is_nvfp4 && std::is_same_v; + constexpr bool need_nvfp4_quant = is_nvfp4 && !is_nvfp4_input; #else - constexpr bool is_fp4 = false; - constexpr bool is_fp4_input = false; - constexpr bool need_fp4_quant = false; + constexpr bool is_mxfp8 = false; + constexpr bool is_mxfp8_input = false; + constexpr bool need_mxfp8_quant = false; + constexpr bool is_nvfp4 = false; + constexpr bool is_nvfp4_input = false; + constexpr bool need_nvfp4_quant = false; #endif - static_assert(need_fp4_quant || PRE_QUANT_AWQ || std::is_same_v, - "Only FP4 and WINT4_AFP8 supports outputting a different format as part of the expansion"); + static_assert(need_nvfp4_quant || need_mxfp8_quant || PRE_QUANT_AWQ + || std::is_same_v, + "Only NVFP4, MXFP8 and WINT4_AFP8 supports outputting a different format as part of the expansion"); #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.wait;"); #endif - int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node]; + constexpr int VecSize = is_nvfp4 ? TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize + : TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize; + + constexpr int64_t ELEM_PER_THREAD + = (is_nvfp4 || is_mxfp8) ? CVT_FP4_ELTS_PER_THREAD : (128 / sizeof_bits::value); + // This should be VecSize * 4 elements + // We assume at least VecSize alignment or the quantization will fail + constexpr int64_t min_k_dim_alignment = is_nvfp4 ? 
TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4 + : TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX; + int64_t const padded_hidden_size + = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(hidden_size, min_k_dim_alignment); + + int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node]; for (int64_t permuted_row = blockIdx.x; permuted_row < num_valid_tokens; permuted_row += gridDim.x) { int64_t const unpermuted_row = permuted_row_to_unpermuted_row[permuted_row]; // Load 128-bits per thread - constexpr int64_t ELEM_PER_THREAD - = is_fp4 ? CVT_FP4_ELTS_PER_THREAD : (128 / sizeof_bits::value); - constexpr int64_t ELEM_PER_BYTE = is_fp4_input ? 2 : 1; - using DataElem - = std::conditional_t>; - using OutputElem = std::conditional_t; + + constexpr int64_t ELEM_PER_BYTE = is_nvfp4_input ? 2 : 1; + using DataElem = std::conditional_t>>; + using OutputElem = std::conditional_t>>; // Duplicate and permute rows - int64_t const source_k_rank = unpermuted_row / num_rows; - int64_t const source_row = unpermuted_row % num_rows; + int64_t const source_k_rank = unpermuted_row / num_tokens; + int64_t const source_row = unpermuted_row % num_tokens; auto const* source_row_ptr - = reinterpret_cast(unpermuted_input + source_row * cols / ELEM_PER_BYTE); + = reinterpret_cast(unpermuted_input + source_row * hidden_size / ELEM_PER_BYTE); // Cast first to handle when this is FP4 - auto* dest_row_ptr = reinterpret_cast(permuted_output) + permuted_row * cols / ELEM_PER_THREAD; + auto* dest_row_ptr + = reinterpret_cast(permuted_output) + permuted_row * hidden_size / ELEM_PER_THREAD; int64_t const start_offset = threadIdx.x; int64_t const stride = EXPAND_THREADS_PER_BLOCK; - int64_t const num_elems_in_col = cols / ELEM_PER_THREAD; - assert(cols % ELEM_PER_THREAD == 0); + int64_t const num_elems_in_col = hidden_size / ELEM_PER_THREAD; + assert(hidden_size % ELEM_PER_THREAD == 0); + assert(hidden_size % VecSize == 0); - if constexpr (is_fp4) + if constexpr (is_nvfp4 || is_mxfp8) { + static_assert(ELEM_PER_THREAD == 8, "Expecting 8 elements per thread for quantized types"); int64_t expert = findTotalEltsLessThanTarget( expert_first_token_offset, num_experts_per_node, (int64_t) permuted_row + 1) - 1; + + assert(!fc1_act_global_scale || is_nvfp4 && "Global scale is only supported for NVFP4"); size_t act_scale_idx = use_per_expert_act_scale ? expert : 0; float global_scale_val = fc1_act_global_scale ? fc1_act_global_scale[act_scale_idx] : 1.0f; int64_t num_tokens_before_expert = expert_first_token_offset[expert]; @@ -1488,37 +1540,46 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { auto in_vec = source_row_ptr[elem_index]; - if constexpr (need_fp4_quant) + if constexpr (need_nvfp4_quant || need_mxfp8_quant) { - auto res = quantizePackedFP4Value(in_vec, global_scale_val, - num_tokens_before_expert, expert, permuted_row, elem_index, cols, num_rows, fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); + auto res = quantizePackedFPXValue( + in_vec, global_scale_val, num_tokens_before_expert, expert, permuted_row, elem_index, + padded_hidden_size, fc1_act_sf_flat, + is_nvfp4 ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); + static_assert( + sizeof(res) == sizeof(*dest_row_ptr), "Quantized value must be the same size as the output"); dest_row_ptr[elem_index] = res; } else { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); - writeSF(num_tokens_before_expert, expert, source_row, permuted_row, elem_index, cols, num_rows, - fc1_act_sf_flat, input_sf); - dest_row_ptr[elem_index] = reinterpret_cast(in_vec); + writeSF(num_tokens_before_expert, expert, source_row, permuted_row, + elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); + dest_row_ptr[elem_index] = in_vec; } } + + // Pad zeros in the extra SFs along the K dimension, we do this to ensure there are no nan values in the + // padded SF atom Use VecSize per thread since we are just writing out zeros so every thread can process a + // whole vector + size_t padding_start_offset = hidden_size / VecSize + start_offset; + size_t padding_elems_in_col = padded_hidden_size / VecSize; + for (int64_t elem_index = padding_start_offset; elem_index < padding_elems_in_col; elem_index += stride) + { + writeSF(num_tokens_before_expert, expert, /*source_row*/ -1, permuted_row, elem_index, + padded_hidden_size, fc1_act_sf_flat, + /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0 + } } else if constexpr (PRE_QUANT_AWQ) { - using InputElem = cutlass::Array; - using OutputElem_ = cutlass::Array; - using OutputElem_AWQ = std::conditional_t; - auto const* source_row_ptr_awq - = reinterpret_cast(unpermuted_input + source_row * cols / ELEM_PER_BYTE); - auto* dest_row_ptr_awq - = reinterpret_cast(permuted_output) + permuted_row * cols / ELEM_PER_THREAD; - cutlass::NumericArrayConverter converter; - InputElem frag_elems; - + static_assert(!is_nvfp4 && !is_mxfp8, "NVFP4 and MXFP8 are not supported for AWQ"); + static_assert(!std::is_same_v, + "Input and output types must be different for AWQ"); for (int elem_index = start_offset; elem_index < num_elems_in_col; elem_index += stride) { - frag_elems = source_row_ptr_awq[elem_index]; + auto frag_elems = source_row_ptr[elem_index]; CUTLASS_PRAGMA_UNROLL for (int e = 0; e < ELEM_PER_THREAD; e++) @@ -1526,7 +1587,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp frag_elems[e] = frag_elems[e] * prequant_scales[elem_index * ELEM_PER_THREAD + e]; } - dest_row_ptr_awq[elem_index] = converter(frag_elems); + dest_row_ptr[elem_index] = arrayConvert(frag_elems); } } else @@ -1543,43 +1604,121 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp permuted_scales[permuted_row] = unpermuted_scales ? unpermuted_scales[source_k_idx] : 1.0f; } } + #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif + + // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded SF + // atom + if constexpr (is_nvfp4 || is_mxfp8) + { + int64_t const start_offset = threadIdx.x; + int64_t const stride = EXPAND_THREADS_PER_BLOCK; + // Use VecSize per thread since we are just writing out zeros so every thread can process a whole vector + int64_t const padded_num_elems_in_col = padded_hidden_size / VecSize; + assert(padded_hidden_size % VecSize == 0); + + constexpr int min_num_tokens_alignment = is_nvfp4 ? 
TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 + : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX; + static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0, + "Min num tokens alignment must be a power of two"); + // Since we don't know a priori how much padding is needed we assume the max per expert + // NOTE: we don't use (min_num_tokens_alignment-1) to be able to do power of two divisions + int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node; + + for (int64_t padding_token = blockIdx.x; padding_token < num_padding_tokens; padding_token += gridDim.x) + { + int64_t expert = padding_token / min_num_tokens_alignment; + int64_t num_tokens_before_expert = expert_first_token_offset[expert]; + int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1]; + int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert; + int64_t padding_to_expert + = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(tokens_to_expert, min_num_tokens_alignment) + - tokens_to_expert; + int64_t expert_pad_idx = padding_token % min_num_tokens_alignment; + if (expert_pad_idx < padding_to_expert) + { + for (int64_t elem_index = start_offset; elem_index < padded_num_elems_in_col; elem_index += stride) + { + writeSF(num_tokens_before_expert, expert, /*source_row*/ -1, + num_tokens_after_expert + expert_pad_idx, elem_index, padded_hidden_size, fc1_act_sf_flat, + /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0 + } + } + } + } } -template +template void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales, - int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int const k, - int const num_experts_per_node, float const* fc1_act_global_scale, bool use_per_expert_act_scale, + int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k, + int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream, - void const* prequant_scales = nullptr) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream) { #ifdef ENABLE_FP4 - // TODO Currently this is a bit hacky because we assume we are in FP8_MXFP4 mode if activations are FP8. - // This code is still needed if we add MXFP8_MXFP4 mode. - // TODO This is also wasteful, we should solve this properly by properly writing the padding in the kernel - if (fc1_act_sf_flat && std::is_same_v) - { - size_t num_elems = getOffsetActivationSF(num_experts_per_node, num_rows * std::min(k, num_experts_per_node), - cols, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); - check_cuda_error(cudaMemsetAsync( - fc1_act_sf_flat, 0x0, num_elems * sizeof(TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF), stream)); - } - else - { - TLLM_CHECK_WITH_INFO( - !use_per_expert_act_scale, "Per-expert act scale for FC1 is only supported for FP4 activations"); - } + TLLM_CHECK_WITH_INFO( + (std::is_same_v && fc1_act_sf_flat) || !use_per_expert_act_scale, + "Per-expert act scale for FC1 is only supported for NVFP4 activations"); + constexpr int64_t min_num_tokens_alignment = std::is_same_v + ? 
TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 + : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX; + int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node; +#else + int64_t num_padding_tokens = 0; #endif - static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). - int64_t const blocks = smCount * 8; + int64_t const blocks = std::min(smCount * 8, std::max(num_rows * k, num_padding_tokens)); int64_t const threads = EXPAND_THREADS_PER_BLOCK; - auto func = expandInputRowsKernel; + + auto func = [&]() + { +#ifdef ENABLE_FP8 + // Always MXFP8 + if constexpr (std::is_same_v + && !std::is_same_v) + { + TLLM_CHECK_WITH_INFO(quant_params.mxfp8_mxfp4.fc1.weight_block_scale || prequant_scales, + "MXFP8xMXFP4 block scaling or prequant_scales or prequant_scales parameters not provided"); + return prequant_scales ? &expandInputRowsKernel + : &expandInputRowsKernel; + } + // Could be either regular FP8 or MXFP8 + else if constexpr (std::is_same_v + && std::is_same_v) + { + TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ"); + return quant_params.mxfp8_mxfp4.fc1.weight_block_scale + ? &expandInputRowsKernel + : &expandInputRowsKernel; + } + else +#endif +#ifdef ENABLE_FP4 + if constexpr (std::is_same_v) + { + TLLM_CHECK_WITH_INFO( + quant_params.fp4.fc1.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4"); + TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ"); + return &expandInputRowsKernel; + } + else +#endif + { + TLLM_CHECK_WITH_INFO(!prequant_scales, "w4afp8 Prequant scales provided for non-FP8 data type"); + return &expandInputRowsKernel; + } + }(); cudaLaunchConfig_t config; config.gridDim = blocks; @@ -1592,23 +1731,24 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales, - permuted_row_to_unpermuted_row, num_rows, cols, k, fc1_act_global_scale, use_per_expert_act_scale, - expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node, + permuted_row_to_unpermuted_row, num_rows, hidden_size, k, quant_params.fp4.fc1.act_global_scale, + use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node, reinterpret_cast(prequant_scales)); } #define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \ - template void expandInputRowsKernelLauncher( \ + template void expandInputRowsKernelLauncher( \ InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, \ float const* unpermuted_scales, float* permuted_scales, int const* permuted_row_to_unpermuted_row, \ - int64_t const num_rows, int64_t const cols, int const k, int const num_experts_per_node, \ - float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, \ + int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, \ + QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, \ TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \ - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream, \ - void const* prequant_scales); + 
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, \ + cudaStream_t stream) -INSTANTIATE_EXPAND_INPUT_ROWS(half, half); +// Instantiate the data types that are used by the external pytorch op INSTANTIATE_EXPAND_INPUT_ROWS(float, float); +INSTANTIATE_EXPAND_INPUT_ROWS(half, half); #ifdef ENABLE_BF16 INSTANTIATE_EXPAND_INPUT_ROWS(__nv_bfloat16, __nv_bfloat16); #endif @@ -1679,7 +1819,6 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted ComputeElem expert_result = arrayConvert(expanded_permuted_rows_row_ptr[elem_index]); - if (bias) { auto const* bias_ptr = bias_v + expert_id * num_elems_in_col; @@ -1861,6 +2000,7 @@ void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_ro int64_t const experts_per_token, int64_t const num_experts_per_node, MOEParallelismConfig parallelism_config, \ bool const enable_alltoall, cudaStream_t stream); +// Instantiate the data types that are used by the external pytorch op INSTANTIATE_FINALIZE_MOE_ROUTING(half, half, half); INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float); #ifdef ENABLE_BF16 @@ -1924,17 +2064,21 @@ void doGatedActivation(ActivationOutputType* output, GemmOutputType const* gemm_ // ============================== Activation ================================= -template class ActFn> +template class ActFn, + TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType BlockScalingType> __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias_ptr, bool bias_is_broadcast, int64_t const* expert_first_token_offset, - int num_experts_per_node, int64_t inter_size, int64_t max_tokens_per_expert, bool gated, - float const* fc2_act_global_scale, bool use_per_expert_act_scale, - TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat) + int num_experts_per_node, int64_t inter_size, bool gated, float const* fc2_act_global_scale, + bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat) { #ifdef ENABLE_FP4 - constexpr bool IsFP4 = std::is_same_v; + constexpr bool IsNVFP4 = std::is_same_v + && BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4; + constexpr bool IsMXFP8 = std::is_same_v + && BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX; #else - constexpr bool IsFP4 = cute::dependent_false; + constexpr bool IsNVFP4 = cute::dependent_false; + constexpr bool IsMXFP8 = cute::dependent_false; #endif int64_t const tid = threadIdx.x; @@ -1945,6 +2089,19 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, asm volatile("griddepcontrol.wait;"); #endif + constexpr int64_t VecSize = IsNVFP4 ? TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize + : TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize; + // Load 128-bits per thread, according to the smallest data type we read/write + constexpr int64_t ACTIVATION_ELEM_PER_THREAD = (IsNVFP4 || IsMXFP8) + ? CVT_FP4_ELTS_PER_THREAD + : (128 / std::min(sizeof_bits::value, sizeof_bits::value)); + + // This should be VecSize * 4 elements + // We assume at least VecSize alignment or the quantization will fail + int64_t const min_k_dim_alignment = IsNVFP4 ? 
TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4 + : TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX; + int64_t const padded_inter_size = ceilDiv(inter_size, min_k_dim_alignment) * min_k_dim_alignment; + int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node]; for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x) @@ -1953,7 +2110,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, size_t output_offset = token * inter_size; int64_t expert = 0; - if (bias_ptr || IsFP4 || use_per_expert_act_scale) + if (bias_ptr || IsNVFP4 || IsMXFP8 || use_per_expert_act_scale) { // TODO this is almost certainly faster as a linear scan expert = findTotalEltsLessThanTarget(expert_first_token_offset, num_experts_per_node, token + 1) - 1; @@ -1964,7 +2121,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, // Some globals for FP4 float global_scale_val = fc2_act_global_scale ? fc2_act_global_scale[act_scale_idx] : 1.0f; - int64_t num_tokens_before_expert = IsFP4 ? expert_first_token_offset[expert] : 0; + int64_t num_tokens_before_expert = (IsNVFP4 || IsMXFP8) ? expert_first_token_offset[expert] : 0; size_t bias_offset = 0; if (bias_ptr) @@ -1972,14 +2129,10 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, bias_offset = (bias_is_broadcast ? expert * inter_size * gated_size_mul : gemm_result_offset); } - // Load 128-bits per thread, according to the smallest data type we read/write - constexpr int64_t ACTIVATION_ELEM_PER_THREAD = IsFP4 - ? CVT_FP4_ELTS_PER_THREAD - : (128 / std::min(sizeof_bits::value, sizeof_bits::value)); - using BiasElem = cutlass::Array; using GemmResultElem = cutlass::Array; - using OutputElem = std::conditional_t>; + using OutputElem = std::conditional_t>>; using ComputeElem = cutlass::Array; // Aliases gemm_result for non-gated, non-fp8 cases auto gemm_result_vec = reinterpret_cast(gemm_result + gemm_result_offset); @@ -2015,12 +2168,15 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, auto post_act_val = gate_act * quant_scale; - if constexpr (IsFP4) + if constexpr (IsNVFP4 || IsMXFP8) { // We use GemmOutputType as the intermediate compute type as that should always be unquantized - auto res = quantizePackedFP4Value(post_act_val, global_scale_val, - num_tokens_before_expert, expert, token, elem_index, inter_size, max_tokens_per_expert, - fc2_act_sf_flat, TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4); + auto res = quantizePackedFPXValue(post_act_val, + global_scale_val, num_tokens_before_expert, expert, token, elem_index, inter_size, fc2_act_sf_flat, + IsNVFP4 ? 
TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4 + : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); + static_assert( + sizeof(res) == sizeof(*output_vec), "Quantized value must be the same size as the output"); output_vec[elem_index] = res; } else @@ -2028,33 +2184,135 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result, output_vec[elem_index] = arrayConvert(post_act_val); } } + + // Pad zeros in the extra SFs along the K dimension, we do this to ensure there are no nan values in the padded + // SF atom + if constexpr (IsNVFP4 || IsMXFP8) + { + // Use VecSize per thread since we are just writing out zeros so every thread can process a whole vector + size_t padding_start_offset = inter_size / VecSize + start_offset; + size_t padding_elems_in_col = padded_inter_size / VecSize; + for (int64_t elem_index = padding_start_offset; elem_index < padding_elems_in_col; elem_index += stride) + { + writeSF(num_tokens_before_expert, expert, /*source_row*/ -1, token, elem_index, + padded_inter_size, fc2_act_sf_flat, /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0 + } + } } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) asm volatile("griddepcontrol.launch_dependents;"); #endif + + // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded SF + // atom + if constexpr (IsNVFP4 || IsMXFP8) + { + int64_t const start_offset = threadIdx.x; + int64_t const stride = ACTIVATION_THREADS_PER_BLOCK; + // Use VecSize per thread since we are just writing out zeros so every thread can process a whole vector + int64_t const padded_num_elems_in_col = padded_inter_size / VecSize; + assert(padded_inter_size % VecSize == 0); + + constexpr int64_t min_num_tokens_alignment = IsNVFP4 + ? 
TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 + : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX; + static_assert((min_num_tokens_alignment & (min_num_tokens_alignment - 1)) == 0, + "Min num tokens alignment must be a power of two"); + // Since we don't know a priori how much padding is needed we assume the max per expert + // NOTE: we don't (min_num_tokens_alignment-1) to have power of two divisions + int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node; + + for (int64_t padding_token = blockIdx.x; padding_token < num_padding_tokens; padding_token += gridDim.x) + { + int64_t expert = padding_token / min_num_tokens_alignment; + int64_t num_tokens_before_expert = expert_first_token_offset[expert]; + int64_t num_tokens_after_expert = expert_first_token_offset[expert + 1]; + int64_t tokens_to_expert = num_tokens_after_expert - num_tokens_before_expert; + int64_t padding_to_expert + = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(tokens_to_expert, min_num_tokens_alignment) + - tokens_to_expert; + int64_t expert_pad_idx = padding_token % min_num_tokens_alignment; + if (expert_pad_idx < padding_to_expert) + { + for (int64_t elem_index = start_offset; elem_index < padded_num_elems_in_col; elem_index += stride) + { + // The SF buffer is padded to a multiple of MinNDimAlignment for each expert + // This means we can safely write to offset num_tokens_after_expert + padded_token, since the next + // expert will leave space for the padding + writeSF(num_tokens_before_expert, expert, /*source_row*/ -1, + num_tokens_after_expert + expert_pad_idx, elem_index, padded_inter_size, fc2_act_sf_flat, + /* input_sf */ nullptr); // Pass nulltpr input_sf so we write 0 + } + } + } + } } template void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8_quant, ScaleBiasType const* bias, bool bias_is_broadcast, int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size, - int64_t num_tokens, int64_t expanded_num_tokens, ActivationType activation_type, float const* fc2_act_global_scale, + int64_t expanded_num_tokens, ActivationType activation_type, QuantParams const& quant_params, bool use_per_expert_act_scale, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, cudaStream_t stream) { - static int const smCount = tensorrt_llm::common::getMultiProcessorCount(); + +#ifdef ENABLE_FP4 + constexpr int64_t min_num_tokens_alignment = std::is_same_v + ? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4 + : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX; + int64_t num_padding_tokens = min_num_tokens_alignment * num_experts_per_node; +#else + int64_t num_padding_tokens = 0; +#endif + + static int64_t const smCount = tensorrt_llm::common::getMultiProcessorCount(); // Note: Launching 8 blocks per SM can fully leverage the memory bandwidth (tested on B200). 
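    // Editor's note (illustrative, not part of the patch): the grid now has to cover both the real
    // expanded tokens and the worst-case SF padding rows written by the N-dimension padding loop,
    // hence num_padding_tokens = MinNDimAlignment * num_experts_per_node above. For example, with a
    // 128-row alignment and 8 experts per node that is 1024 padding rows, so even a tiny batch
    // launches enough blocks to zero every expert's padded SF rows:
    //   blocks = min(smCount * 8, max(expanded_num_tokens, num_padding_tokens))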
- int64_t const blocks = smCount * 8; + int64_t const blocks = std::min(smCount * 8, std::max(expanded_num_tokens, num_padding_tokens)); int64_t const threads = ACTIVATION_THREADS_PER_BLOCK; - auto fn_list = std::array{ - &doActivationKernel, // Gelu - &doActivationKernel, // Relu - &doActivationKernel, // Silu - &doActivationKernel, // Swiglu - &doActivationKernel, // Geglu - &doActivationKernel // Identity - }; - auto fn = fn_list[static_cast(activation_type)]; + auto fn = [&]() + { + auto fn = [&](auto block_scaling_type) + { + auto fn_list = std::array{ + &doActivationKernel, // Gelu + &doActivationKernel, // Relu + &doActivationKernel, // Silu + &doActivationKernel, // Swiglu + &doActivationKernel, // Geglu + &doActivationKernel // Identity + }; + return fn_list[static_cast(activation_type)]; + }; + auto NVFP4 = tensorrt_llm::common::ConstExprWrapper{}; + auto MXFPX = tensorrt_llm::common::ConstExprWrapper{}; + auto NONE = tensorrt_llm::common::ConstExprWrapper{}; +#ifdef ENABLE_FP4 + if constexpr (std::is_same_v) + { + TLLM_CHECK_WITH_INFO( + quant_params.fp4.fc2.weight_block_scale, "NVFP4 block scaling is expected for FP4xFP4"); + return fn(NVFP4); + } + else if constexpr (std::is_same_v) + { + return quant_params.mxfp8_mxfp4.fc2.weight_block_scale ? fn(MXFPX) : fn(NONE); + } + else +#endif + { + return fn(NONE); + } + }(); cudaLaunchConfig_t config; config.gridDim = blocks; @@ -2067,7 +2325,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, expert_first_token_offset, - num_experts_per_node, inter_size, num_tokens, isGatedActivation(activation_type), fc2_act_global_scale, + num_experts_per_node, inter_size, isGatedActivation(activation_type), quant_params.fp4.fc2.act_global_scale, use_per_expert_act_scale, fc2_act_sf_flat); } @@ -2309,6 +2567,8 @@ CutlassMoeFCRunner:: ? 0 : num_moe_inputs * hidden_size * gemm_output_dtype; // May be an intermediate type for quantization + // If topk is greater than num_experts_per_node (i.e. large EP value), then we don't need to allocate for the whole + // tokens*topk auto act_sf_rows = min_latency_mode ? num_moe_inputs : std::min(num_moe_inputs, static_cast(num_rows * num_experts_per_node)); @@ -2582,8 +2842,7 @@ void CutlassMoeFCRunner(output, static_cast(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, - inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, use_per_expert_act_scale, nullptr, - stream); + inter_size, expanded_num_rows, fc1_activation_type, quant_params, use_per_expert_act_scale, nullptr, stream); sync_check_cuda_error(stream); } @@ -2717,20 +2976,6 @@ void CutlassMoeFCRunner; bool use_per_expert_act_scale = use_fp4 ? 
quant_params.fp4.fc2.use_per_expert_act_scale @@ -2740,9 +2985,8 @@ void CutlassMoeFCRunner(reinterpret_cast(output), static_cast(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, - expert_first_token_offset, num_experts_per_node, inter_size, num_rows, expanded_num_rows, - fc1_activation_type, quant_params.fp4.fc2.act_global_scale, use_per_expert_act_scale, fc2_fp4_act_flat, - stream); + expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows, fc1_activation_type, + quant_params, use_per_expert_act_scale, fc2_fp4_act_flat, stream); sync_check_cuda_error(stream); } @@ -2766,7 +3010,7 @@ void CutlassMoeFCRunner(output, static_cast(intermediate_result), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, - inter_size, num_rows, expanded_num_rows, fc1_activation_type, nullptr, use_per_expert_act_scale, nullptr, + inter_size, expanded_num_rows, fc1_activation_type, quant_params, use_per_expert_act_scale, nullptr, stream); sync_check_cuda_error(stream); @@ -2930,7 +3174,7 @@ void CutlassMoeFCRunner(gemm_output), static_cast(gemm_output), nullptr, static_cast(fc2_lora), false, expert_first_token_offset, num_experts_per_node, - hidden_size, num_rows, expanded_num_rows, ActivationType::Identity, nullptr, false, nullptr, stream); + hidden_size, expanded_num_rows, ActivationType::Identity, {}, false, nullptr, stream); sync_check_cuda_error(stream); } @@ -3208,17 +3452,25 @@ void CutlassMoeFCRunner::value) == 0, - "Hidden size does not meet minimum alignment requirements for MOE GEMM"); - // Require at least 64 bytes of alignment for MOE GEMM - TLLM_CHECK_WITH_INFO(inter_size % (128 / sizeof_bits::value) == 0, - "Inter size does not meet minimum alignment requirements for MOE GEMM"); - if (weight_fp4) + + if (quant_params.mxfp8_mxfp4.fc1.weight_block_scale) { - TLLM_CHECK_WITH_INFO( - hidden_size % 128 == 0, "Hidden size does not meet minimum alignment requirements for MOE GEMM"); - TLLM_CHECK_WITH_INFO( - inter_size % 128 == 0, "Inter size does not meet minimum alignment requirements for MOE GEMM"); + TLLM_CHECK_WITH_INFO(hidden_size % (64 * 8 / sizeof_bits::value) == 0, + "Hidden size %d does not meet minimum alignment requirements for MXFP8_MXFP4 MOE GEMM %d", + (int) hidden_size, (int) (64 * 8 / sizeof_bits::value)); + TLLM_CHECK_WITH_INFO(inter_size % (64 * 8 / sizeof_bits::value) == 0, + "Inter size %d does not meet minimum alignment requirements for MXFP8_MXFP4 MOE GEMM %d", (int) inter_size, + (int) (64 * 8 / sizeof_bits::value)); + } + else + { + // Require at least 128 bits of alignment for MOE GEMM + TLLM_CHECK_WITH_INFO(hidden_size % (128 / sizeof_bits::value) == 0, + "Hidden size %d does not meet minimum alignment requirements for MOE GEMM %d", (int) hidden_size, + (int) (128 / sizeof_bits::value)); + TLLM_CHECK_WITH_INFO(inter_size % (128 / sizeof_bits::value) == 0, + "Inter size %d does not meet minimum alignment requirements for MOE GEMM %d", (int) inter_size, + (int) (128 / sizeof_bits::value)); } // These values must fit into an int for building the source maps @@ -3367,29 +3619,14 @@ void CutlassMoeFCRunner; // Only NVFP4xNVFP4 supports FC1 per-expert act scale bool use_per_expert_act_scale = use_fp4 ? 
quant_params.fp4.fc1.use_per_expert_act_scale : false; - T const* gemm1_input; - if constexpr (use_w4afp8) - { - // FP16/BF16 input_activations -> FP8 smoothed_act - expandInputRowsKernelLauncher(input_activations, reinterpret_cast(smoothed_act_), - token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, - hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale, - use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream, - quant_params.groupwise.fc1.act_scales); - - gemm1_input = reinterpret_cast(smoothed_act_); - } - else - { - expandInputRowsKernelLauncher(input_activations, reinterpret_cast(permuted_data_), - token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, - hidden_size, experts_per_token, num_experts_per_node, quant_params.fp4.fc1.act_global_scale, - use_per_expert_act_scale, expert_first_token_offset_, fc1_fp4_act_scale_, input_sf, stream); - gemm1_input = reinterpret_cast(permuted_data_); - } + T* gemm1_input_expand = use_w4afp8 ? reinterpret_cast(smoothed_act_) : reinterpret_cast(permuted_data_); + expandInputRowsKernelLauncher(input_activations, gemm1_input_expand, token_topk_unpermuted_scales, + permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token, + num_experts_per_node, quant_params, use_per_expert_act_scale, expert_first_token_offset_, + fc1_fp4_act_scale_, input_sf, use_w4afp8 ? quant_params.groupwise.fc1.act_scales : nullptr, stream); + auto const* gemm1_input = gemm1_input_expand; sync_check_cuda_error(stream); @@ -3664,12 +3901,6 @@ CutlassMoeFCRunner:: = std::max(fc1_sf_offset, fc2_sf_offset) * sizeof(TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF); check_cuda_error(cudaMemsetAsync(fc1_fp4_act_scale_, weight_block_scale_value_int, max_size, stream)); } - else - { - // TODO This will involve updating the expandInputRowsKernel and doActivationKernel to support MXFP8 - // quantization - TLLM_CHECK_WITH_INFO(!use_wfp4afp8, "WFP4AFP8 with true MXFP8 weights is not implemented yet"); - } TLLM_CHECK_WITH_INFO(gemm1_input != gemm1_output, "Input and output buffers are overlapping"); return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset_, gemm1_tma_ws_input, @@ -4440,10 +4671,12 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat1 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half>; template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, half, half>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, half, half>; #ifdef ENABLE_BF16 template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp4_e2m1, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>; template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16>; +template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp4_e2m1, __nv_bfloat16, __nv_bfloat16>; #endif #endif diff --git a/cpp/tensorrt_llm/kernels/quantization.cuh b/cpp/tensorrt_llm/kernels/quantization.cuh index cbb89579eda..95768fecfee 100644 --- a/cpp/tensorrt_llm/kernels/quantization.cuh +++ b/cpp/tensorrt_llm/kernels/quantization.cuh @@ -398,6 +398,95 @@ struct PackedVec<__nv_fp8_e4m3> "Vector size should match the number of elements per thread."); }; +// Convert 4 float2 values into 8 e4m3 values (represented as one 
uint64_t). +inline __device__ uint64_t fp32_vec_to_e4m3(float2 (&array)[4]) +{ + union + { + uint64_t val; + __nv_fp8x2_e4m3 elts[4]; + } u; + + static_assert(sizeof(u.val) == sizeof(u.elts), "Expected to alias uint64_t and __nv_fp8x2_e4m3[4]"); + + u.elts[0] = __nv_fp8x2_e4m3(array[0]); + u.elts[1] = __nv_fp8x2_e4m3(array[1]); + u.elts[2] = __nv_fp8x2_e4m3(array[2]); + u.elts[3] = __nv_fp8x2_e4m3(array[3]); + return u.val; +} + +// Quantizes the provided PackedVec into the uint64_t output +template +__device__ uint64_t cvt_warp_fp16_to_mxfp8(PackedVec& vec, uint8_t* SFout) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = cuda_abs(vec.elts[0]); + +// Local maximum value. +#pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + { + localMax = cuda_max(localMax, cuda_abs(vec.elts[i])); + } + + constexpr int CVT_NUM_THREADS_PER_SF = SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD; + // Get the absolute maximum among all 16 values (two threads for 16, four threads for 32). + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + if constexpr (CVT_NUM_THREADS_PER_SF == 4) + { + localMax = cuda_max(__shfl_xor_sync(uint32_t(-1), localMax, 2), localMax); + } + // Get the final absolute maximum values. + float vecMax = float(cuda_max(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of mxfp8). + float SFValue = vecMax * reciprocal_approximate_ftz(448.0f); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + __nv_fp8_e8m0 tmpSFVal; + tmpSFVal.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); + float SFValueNarrow = static_cast(tmpSFVal); + fp8SFVal = tmpSFVal.__x; + // Get the output scale (reciprocal of the SFValue). + float outputScale = SFValue != 0.f ? reciprocal_approximate_ftz(SFValueNarrow) : 0.0f; + + if (SFout) + { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + +#pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + { + if constexpr (std::is_same_v) + { + fp2Vals[i] = __half22float2(vec.elts[i]); + } + else + { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e4m3 values. + uint64_t e4m3Vec = fp32_vec_to_e4m3(fp2Vals); + + // Write the e4m3 values to global memory. + return e4m3Vec; +#else + return 0; +#endif +} + // Quantizes the provided PackedVec into the uint32_t output template __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) @@ -427,6 +516,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, // maximum value of e2m1 = 6.0. // TODO: use half as compute data type. float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + float SFValueNarrow; // 8 bits representation of the SF. uint8_t fp8SFVal; // Write the SF to global memory (STG.8). @@ -434,7 +524,7 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, { __nv_fp8_e8m0 tmp; tmp.__x = __nv_cvt_float_to_e8m0(SFValue, __NV_SATFINITE, cudaRoundPosInf); - SFValue = static_cast(tmp); + SFValueNarrow = static_cast(tmp); fp8SFVal = tmp.__x; } else @@ -442,12 +532,12 @@ __device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, // Here SFValue is always positive, so E4M3 is the same as UE4M3. 
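        // Editor's note on the change below: SFValueNarrow holds SFValue after a round trip through
        // the 8-bit SF format (UE8M0 or E4M3). It is the value that matches the scale factor actually
        // written to SFout, so the output scale is derived from SFValueNarrow, while SFValue keeps the
        // pre-rounding value instead of being clobbered as it was before this patch.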
__nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); fp8SFVal = tmp.__x; - SFValue = static_cast(tmp); + SFValueNarrow = static_cast(tmp); } // Get the output scale. // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * reciprocal(SFScaleVal)) float outputScale - = SFValue != 0 ? reciprocal_approximate_ftz(SFValue * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; + = SFValue != 0 ? reciprocal_approximate_ftz(SFValueNarrow * reciprocal_approximate_ftz(SFScaleVal)) : 0.0f; if (SFout) { @@ -573,7 +663,7 @@ __device__ uint64_t cvt_warp_fp8_to_fp4(PackedVec& vec, float SFScaleVal, } template -inline __device__ int64_t get_sf_out_offset_128x4( +inline __device__ __host__ int64_t get_sf_out_offset_128x4( std::optional batchIdx, int mIdx, int kIdx, std::optional numRows, int numCols) { // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index d2c196c604a..de626edca73 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -34,6 +34,7 @@ set_property(TARGET th_utils PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) target_link_libraries(th_utils PUBLIC ${TORCH_LIBRARIES} ${CUBLAS_LIB} ${CURAND_LIB}) +# TODO This does not compile with internal cutlass MOE gemm add_library( th_common SHARED mlaPreprocessOp.cpp diff --git a/cpp/tensorrt_llm/thop/moeOp.cpp b/cpp/tensorrt_llm/thop/moeOp.cpp index b3f9ef876e5..1bc2a057ed9 100644 --- a/cpp/tensorrt_llm/thop/moeOp.cpp +++ b/cpp/tensorrt_llm/thop/moeOp.cpp @@ -15,12 +15,13 @@ */ #if defined(USING_OSS_CUTLASS_MOE_GEMM) -#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h" #include "tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h" #else #include "moe_gemm_kernels.h" #include "moe_kernels.h" #endif +// Always include the public header for moe_gemm_kernels.h +#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h" #include "tensorrt_llm/common/workspace.h" #include "tensorrt_llm/kernels/cutlass_kernels/fp8_blockscale_gemm/fp8_blockscale_gemm.h" @@ -46,7 +47,8 @@ namespace torch_ext namespace common = tensorrt_llm::common; namespace kernels = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using ActivationType = CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; -using TmaWarpSpecializedGroupedGemmInput = CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput; +// Always use public header as it is just utility functions and types +using TmaWarpSpecializedGroupedGemmInput = tensorrt_llm::kernels::cutlass_kernels::TmaWarpSpecializedGroupedGemmInput; using profiler_backend = CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::GemmProfilerBackend; class FusedMoeRunner : public torch::CustomClassHolder @@ -92,13 +94,14 @@ class FusedMoeRunner : public torch::CustomClassHolder }; FusedMoeRunner(c10::ScalarType activation_dtype, c10::ScalarType weight_dtype, c10::ScalarType output_dtype, - bool use_deepseek_fp8_block_scale, bool use_w4a8_group_scaling) + bool use_deepseek_fp8_block_scale, bool use_w4a8_group_scaling, bool use_mxfp8_act_scaling) { mActivationDtype = activation_dtype; mWeightDtype = weight_dtype; mOutputDtype = output_dtype; mUseDeepSeekFP8BlockScaling = use_deepseek_fp8_block_scale; mUseW4A8GroupScaling = use_w4a8_group_scaling; + mUseMxfp8ActScaling = use_mxfp8_act_scaling; mInnerDimMultiplier = 1; // keep consistent with cpp/tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.cpp @@ -126,7 +129,7 @@ class FusedMoeRunner : public torch::CustomClassHolder } #endif #ifdef 
ENABLE_FP4 - if (isWFp4AFp8Quant()) + if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant()) { mInnerDimMultiplier = 16; // 16 FP4 -> 1 LONG mKernelRunner = switch_output_type<__nv_fp8_e4m3, __nv_fp4_e2m1>(mOutputDtype); } @@ -275,6 +278,24 @@ class FusedMoeRunner : public torch::CustomClassHolder int64_t num_rows = input.sizes()[0]; int64_t hidden_size = fc2_expert_weights.sizes()[1]; int64_t inter_size = fc2_expert_weights.sizes()[2] * mInnerDimMultiplier; + + if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant()) + { + // MXFP4 weights are required to be aligned to 128 elements + TORCH_CHECK(hidden_size % 128 == 0, "hidden_size must be divisible by 128 for MXFP4 weights"); + TORCH_CHECK(inter_size % 128 == 0, "inter_size must be divisible by 128 for MXFP4 weights"); + } + else + { + // TMA requires at least 128 bit alignment + auto min_alignment + = 128 / (8 * std::min(c10::elementSize(mActivationDtype), c10::elementSize(mWeightDtype))); + TORCH_CHECK(hidden_size % min_alignment == 0, "hidden_size ", hidden_size, " must be divisible by ", + min_alignment, " for weights"); + TORCH_CHECK(inter_size % min_alignment == 0, "inter_size ", inter_size, " must be divisible by ", + min_alignment, " for weights"); + } + int const num_experts_on_rank = fc2_expert_weights.sizes()[0]; auto const num_experts_total = static_cast(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank); @@ -388,6 +409,9 @@ class FusedMoeRunner : public torch::CustomClassHolder TORCH_CHECK(fc1_expert_weights.sizes()[1] == fc2_expert_weights.sizes()[2] * mInnerDimMultiplier * 2, "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."); + + TORCH_CHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant(), + "Block-scaling factors provided for non block-scaling quantization"); + int experts_per_token = token_selected_experts.sizes()[1]; int64_t num_rows = input.sizes()[0]; int64_t hidden_size = fc2_expert_weights.sizes()[1]; @@ -556,6 +580,7 @@ class FusedMoeRunner : public torch::CustomClassHolder bool mUseDeepSeekFP8BlockScaling = false; bool mUseW4A8GroupScaling = false; + bool mUseMxfp8ActScaling = false; using Profile = tensorrt_llm::cutlass_extensions::CutlassGemmConfig; std::vector mAllProfiles; @@ -655,11 +680,10 @@ class FusedMoeRunner : public torch::CustomClassHolder /* fp8 output quant scale */ nullptr, static_cast(fc1_input_dequant.data_ptr()), fc2_quant.dim() == 1); } - - else if (isWFp4AFp8Quant()) + else if (isWMxfp4AFp8Quant()) { - TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for WFP4AFP8 quantization"); - TORCH_CHECK(quant_scales.value().size() == 5, "Expecting 5 quant scales for WFP4AFP8 quantization"); + TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4A8_MXFP4_FP8 quantization"); + TORCH_CHECK(quant_scales.value().size() == 5, "Expecting 5 quant scales for W4A8_MXFP4_FP8 quantization"); auto const fc1_weight_block = quant_scales.value()[0]; auto const fc1_global = quant_scales.value()[1]; @@ -684,19 +708,27 @@ class FusedMoeRunner : public torch::CustomClassHolder TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D"); // Check shapes TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank - && fc1_weight_block.sizes()[1] == inter_size * 2 + && fc1_weight_block.sizes()[1] + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) + * 2 && fc1_weight_block.sizes()[2] * FP8_PER_INT32 *
TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize - == hidden_size, + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX), "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // " "block_scale_vector_size)"); TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)"); TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank, "fc2 act global must be scalar or (num_experts_on_rank,)"); - TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank && fc2_weight_block.sizes()[1] == hidden_size + TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank + && fc2_weight_block.sizes()[1] + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) && fc2_weight_block.sizes()[2] * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize - == inter_size, + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX), "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " "block_scale_vector_size)"); TORCH_CHECK(fc2_global.sizes()[0] == num_experts_on_rank, "fc2 global size must be (num_experts_on_rank,)"); @@ -707,6 +739,62 @@ class FusedMoeRunner : public torch::CustomClassHolder static_cast(fc2_weight_block.data_ptr()), static_cast(fc2_global.data_ptr()), false, fc2_act_global.dim() == 1); } + else if (isWMxfp4AMxfp8Quant()) + { +#ifdef USING_OSS_CUTLASS_MOE_GEMM + TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for W4A8_MXFP4_MXFP8 quantization"); + TORCH_CHECK(quant_scales.value().size() == 4, "Expecting 4 quant scales for W4A8_MXFP4_MXFP8 quantization"); + + auto const fc1_weight_block = quant_scales.value()[0]; + auto const fc1_global = quant_scales.value()[1]; + auto const fc2_weight_block = quant_scales.value()[2]; + auto const fc2_global = quant_scales.value()[3]; + + // The input for scale fc1_weight_block / fc2_weight_block is packed into INT32 + constexpr int FP8_PER_INT32 = 4; + CHECK_INPUT(fc1_weight_block, c10::ScalarType::Int); + CHECK_INPUT(fc1_global, c10::ScalarType::Float); + CHECK_INPUT(fc2_weight_block, c10::ScalarType::Int); + CHECK_INPUT(fc2_global, c10::ScalarType::Float); + TORCH_CHECK(fc1_weight_block.dim() == 3, "fc1 weight block must be 3D"); + TORCH_CHECK(fc1_global.dim() == 1, "fc1 global must be 1D"); + TORCH_CHECK(fc2_weight_block.dim() == 3, "fc2 weight block must be 3D"); + TORCH_CHECK(fc2_global.dim() == 1, "fc2 global must be 1D"); + TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank + && fc1_weight_block.sizes()[1] + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) + * 2 + && fc1_weight_block.sizes()[2] * FP8_PER_INT32 + * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX), + "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // " + "block_scale_vector_size)"); + TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)"); + TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank + &&
fc2_weight_block.sizes()[1] + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) + && fc2_weight_block.sizes()[2] * FP8_PER_INT32 + * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX), + "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " + "block_scale_vector_size)"); + TORCH_CHECK(fc2_global.sizes()[0] == num_experts_on_rank, "fc2 global size must be (num_experts_on_rank,)"); + + return kernels::QuantParams::MXFP8MXFP4( + static_cast(fc1_weight_block.data_ptr()), + static_cast(fc1_global.data_ptr()), + static_cast(fc2_weight_block.data_ptr()), + static_cast(fc2_global.data_ptr())); +#else + TORCH_CHECK(false, "MXFP8 x MXFP4 quantization requires the OSS Cutlass MoE GEMM (USING_OSS_CUTLASS_MOE_GEMM)"); +#endif + } + else if (isNvfp4Quant()) { TORCH_CHECK(quant_scales.has_value(), "Expecting quant scales for nvfp4 quantization"); @@ -741,19 +829,27 @@ class FusedMoeRunner : public torch::CustomClassHolder TORCH_CHECK(fc1_act_global.dim() == 0 || fc1_act_global.sizes()[0] == num_experts_on_rank, "fc1 act global must be scalar or (num_experts_on_rank,)"); TORCH_CHECK(fc1_weight_block.sizes()[0] == num_experts_on_rank - && fc1_weight_block.sizes()[1] == inter_size * 2 + && fc1_weight_block.sizes()[1] + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4) + * 2 && fc1_weight_block.sizes()[2] * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize - == hidden_size, + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4), "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 // " "block_scale_vector_size)"); TORCH_CHECK(fc1_global.sizes()[0] == num_experts_on_rank, "fc1 global size must be (num_experts_on_rank,)"); TORCH_CHECK(fc2_act_global.dim() == 0 || fc2_act_global.sizes()[0] == num_experts_on_rank, "fc2 act global must be scalar or (num_experts_on_rank,)"); - TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank && fc2_weight_block.sizes()[1] == hidden_size + TORCH_CHECK(fc2_weight_block.sizes()[0] == num_experts_on_rank + && fc2_weight_block.sizes()[1] + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4) && fc2_weight_block.sizes()[2] * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize - == inter_size, + == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4), "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " "block_scale_vector_size)"); TORCH_CHECK(fc2_global.sizes()[0] == num_experts_on_rank, "fc2 global size must be (num_experts_on_rank,)"); @@ -821,9 +917,16 @@ class FusedMoeRunner : public torch::CustomClassHolder return mActivationDtype == c10::ScalarType::Float8_e4m3fn && isInt4Quant(); } - bool isWFp4AFp8Quant() const + bool isWMxfp4AFp8Quant() const + { + return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long + && !mUseMxfp8ActScaling; + } + + bool isWMxfp4AMxfp8Quant() const { - return mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long; + return
mActivationDtype == c10::ScalarType::Float8_e4m3fn && mWeightDtype == c10::ScalarType::Long + && mUseMxfp8ActScaling; } }; @@ -832,7 +935,7 @@ class FusedMoeRunner : public torch::CustomClassHolder TORCH_LIBRARY(trtllm, m) { m.class_("FusedMoeRunner") - .def(torch::init()) + .def(torch::init()) .def("run_gemm_profile", &torch_ext::FusedMoeRunner::runGemmProfile) .def("get_tactic_num", &torch_ext::FusedMoeRunner::getTacticNum) .def("run_moe", &torch_ext::FusedMoeRunner::runMoe) diff --git a/cpp/tensorrt_llm/thop/moeUtilOp.cpp b/cpp/tensorrt_llm/thop/moeUtilOp.cpp index 01dfb9217e2..d939bcd07fc 100644 --- a/cpp/tensorrt_llm/thop/moeUtilOp.cpp +++ b/cpp/tensorrt_llm/thop/moeUtilOp.cpp @@ -82,9 +82,8 @@ void runPermute(void const* input_activations_void, void const* input_sf_void, i cutlass_kernels::expandInputRowsKernelLauncher(input_activations, reinterpret_cast(permuted_data_), token_topk_unpermuted_scales, permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token, - num_experts_per_node, quant_params.fp4.fc1.act_global_scale, /*use_per_expert_act_scale*/ false, - expert_first_token_offset_, - /* fc1_fp4_act_scale_ */ nullptr, input_sf, stream); + num_experts_per_node, quant_params, /*use_per_expert_act_scale*/ false, expert_first_token_offset_, + /* fc1_fp4_act_scale_ */ nullptr, input_sf, /* prequant_scales */ nullptr, stream); sync_check_cuda_error(stream); } diff --git a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu index a44ca2a4a89..c9e4a065eb4 100644 --- a/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu +++ b/cpp/tests/unit_tests/kernels/mixtureOfExpertsTest.cu @@ -97,8 +97,10 @@ using sizeof_bits = cutlass::sizeof_bits; - constexpr static bool FP8 = std::is_same_v && std::is_same_v; + constexpr static bool ACT_FP8 = std::is_same_v; + constexpr static bool WEIGHT_FP8 = std::is_same_v; + constexpr static bool FP8 = ACT_FP8 && WEIGHT_FP8; constexpr static bool ACT_FP4 = std::is_same_v; constexpr static bool WEIGHT_FP4 = std::is_same_v; - constexpr static bool NVFP4 = ACT_FP4 && WEIGHT_FP4; - static_assert(!NVFP4 || !MX_QUANT, "NVFP4 and MX_QUANT are be mutually exclusive"); - constexpr static bool MIXED_FP4 = !ACT_FP4 && WEIGHT_FP4; - static_assert(MIXED_FP4 || !MX_QUANT, "MIXED_FP4 is only supported with MX_QUANT"); + + constexpr static bool MX_QUANT_ACT = std::is_same_v; + constexpr static bool MX_QUANT_WEIGHT = std::is_same_v; + static_assert(!MX_QUANT_ACT || MX_QUANT_WEIGHT, "MX quantized act implies MX quantized weight"); + + constexpr static bool NVFP4 = ACT_FP4 && WEIGHT_FP4 && !MX_QUANT_ACT && !MX_QUANT_WEIGHT; + static_assert(!ACT_FP4 || NVFP4, "FP4 activations is only supported with NVFP4"); + + constexpr static bool MXFP8_MXFP4 = ACT_FP8 && WEIGHT_FP4 && MX_QUANT_ACT && MX_QUANT_WEIGHT; + constexpr static bool FP8_MXFP4 = ACT_FP8 && WEIGHT_FP4 && !MX_QUANT_ACT && MX_QUANT_WEIGHT; constexpr static bool ANY_FP4 = WEIGHT_FP4 || ACT_FP4; constexpr static bool ANY_FPX = ANY_FP4 || FP8; - constexpr static bool INT_QUANT = !std::is_same_v && !MIXED_FP4; + constexpr static bool INT_QUANT = !std::is_same_v && std::is_integral_v; constexpr static int WEIGHT_ELEM_PER_BYTE = (INT4 || WEIGHT_FP4) ? 
2 : 1; - using InputType = std::conditional_t; + using InputType = std::conditional_t; using WeightStorage = std::conditional_t; constexpr static int64_t HIDDEN_SIZE_MULTIPLIER = 16; - constexpr static int64_t MINIMUM_BYTE_ALIGNMENT = 64; + constexpr static int64_t MINIMUM_BYTE_ALIGNMENT + = MX_QUANT_WEIGHT ? 64 : 128 / 8; // TMA requires 128 bits alignment, MX quant requires 64 bytes constexpr static int64_t MINIMUM_ALIGNMENT = MINIMUM_BYTE_ALIGNMENT * WEIGHT_ELEM_PER_BYTE / sizeof(WeightStorage); constexpr static int64_t DEFAULT_HIDDEN_SIZE = HIDDEN_SIZE_MULTIPLIER * MINIMUM_ALIGNMENT; // FP4 uses the unquantized data type for inputs and quantizes on the fly - using DataType = std::conditional_t; + using DataType = std::conditional_t; - // MIXED_FP4 quantizes just the weights on the fly - using WeightRawType = std::conditional_t; + // FP8_MXFP4 quantizes just the weights on the fly + using WeightRawType = std::conditional_t; static BufferManager::CudaStreamPtr mStream; static std::unique_ptr mBufferManager; @@ -165,14 +177,15 @@ protected: float getTolerance(float scale = 1.f) { - bool loose_fp8 = mActType != ActivationType::Relu; + bool loose_tol = mActType != ActivationType::Relu || mUseBias; float tol = std::is_same_v ? 0.1 : std::is_same_v ? 0.1 : std::is_same_v ? 0.001 : std::is_same_v ? 0.005 : std::is_same_v ? 0.05 - : std::is_same_v ? (loose_fp8 ? 0.06 : 0.001) - : std::is_same_v ? 0.05 + : (MXFP8_MXFP4 || FP8_MXFP4) ? (loose_tol ? 0.1 : 0.01) + : std::is_same_v ? (loose_tol ? 0.06 : 0.001) + : NVFP4 ? 0.05 : 0.0; // Keep the scale in a sane range @@ -221,6 +234,7 @@ protected: { managed_buffers.clear(); ASSERT_EQ(cudaStreamSynchronize(mStream->get()), cudaSuccess); + ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess); ASSERT_EQ(cudaGetLastError(), cudaSuccess); } @@ -282,10 +296,21 @@ protected: float* mExpertFP4WeightGlobalScale2{}; using ElementSF = TmaWarpSpecializedGroupedGemmInput::ElementSF; - constexpr static int FP4VecSize = MX_QUANT ? TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize - : TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize; - constexpr static int MinAlignmentFP4 = MX_QUANT ? TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentMXFPX - : TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentNVFP4; + constexpr static int FP4VecSize = MX_QUANT_WEIGHT ? TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize + : TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize; +#ifdef USING_OSS_CUTLASS_MOE_GEMM + constexpr static int MinNDimAlignmentFP4 = MX_QUANT_WEIGHT + ? TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX + : TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4; + constexpr static int MinKDimAlignmentFP4 = MX_QUANT_WEIGHT + ? TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX + : TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4; +#else + constexpr static int MinNDimAlignmentFP4 = MX_QUANT_WEIGHT + ? 
TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentMXFPX + : TmaWarpSpecializedGroupedGemmInput::MinNumRowsAlignmentNVFP4; + constexpr static int MinKDimAlignmentFP4 = FP4VecSize * 4; // Hardcode the correct value +#endif ElementSF* mFP4ScalingFactorsW1 = nullptr; ElementSF* mFP4ScalingFactorsW2 = nullptr; @@ -460,16 +485,20 @@ protected: } else if constexpr (ANY_FP4) { - // TODO We populate these on the fly, so we can probably reduce these by moe_parallel_size mExpertWeight1 = allocBuffer( expert_matrix_size * mGatedMultiplier / WEIGHT_ELEM_PER_BYTE / moe_parallel_size); mExpertWeight2 = allocBuffer(expert_matrix_size / WEIGHT_ELEM_PER_BYTE / moe_parallel_size); - size_t const padded_fc1_size = mNumExperts * mHiddenSize - * cute::ceil_div(mInterSize * mGatedMultiplier / parallelism_config.tp_size, MinAlignmentFP4) - * MinAlignmentFP4 / parallelism_config.ep_size; - size_t const padded_fc2_size = mNumExperts * mInterSize * cute::ceil_div(mHiddenSize, MinAlignmentFP4) - * MinAlignmentFP4 / moe_parallel_size; + size_t const padded_fc1_size = mNumExperts + * TmaWarpSpecializedGroupedGemmInput::alignToSfDim(mHiddenSize, MinKDimAlignmentFP4) + * TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + mInterSize / parallelism_config.tp_size, MinNDimAlignmentFP4) + * mGatedMultiplier / parallelism_config.ep_size; + size_t const padded_fc2_size = mNumExperts + * TmaWarpSpecializedGroupedGemmInput::alignToSfDim(mInterSize, MinKDimAlignmentFP4) + * TmaWarpSpecializedGroupedGemmInput::alignToSfDim( + mHiddenSize / parallelism_config.tp_size, MinNDimAlignmentFP4) + / parallelism_config.ep_size; mFP4ScalingFactorsW1 = allocBuffer(padded_fc1_size / FP4VecSize); mFP4ScalingFactorsW2 = allocBuffer(padded_fc2_size / FP4VecSize); } @@ -572,58 +601,50 @@ protected: void doFP4Quant(WeightRawType const* raw_weights, WeightStorage* quant_weights, float const* global_scales, ElementSF* scaling_factors, int in_shape, int out_shape, int num_experts) { - int const mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); - int padded_stride = cute::ceil_div(out_shape, MinAlignmentFP4) * MinAlignmentFP4; + int64_t const mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); + int64_t padded_out_dim = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(out_shape, MinNDimAlignmentFP4); + int64_t padded_in_dim = TmaWarpSpecializedGroupedGemmInput::alignToSfDim(in_shape, MinKDimAlignmentFP4); check_cuda_error(cudaMemsetAsync(scaling_factors, 0x00, - num_experts * padded_stride * cutlass::ceil_div(in_shape, FP4VecSize) * sizeof(ElementSF), mStream->get())); + num_experts * padded_out_dim * padded_in_dim / FP4VecSize * sizeof(ElementSF), mStream->get())); invokeBatchedFP4Quantization(num_experts, out_shape, in_shape, raw_weights, global_scales, reinterpret_cast(quant_weights), reinterpret_cast(scaling_factors), - MX_QUANT, mMultiProcessorCount, mStream->get()); - // for (int i = 0; i < num_experts; i++) - // { - // auto* weight_start = raw_weights + i * in_shape * out_shape; - // auto* quant_weight_start = quant_weights + i * in_shape * out_shape / WEIGHT_ELEM_PER_BYTE; - // auto* scaling_factor_start - // = scaling_factors + i * (int64_t) padded_stride * cutlass::ceil_div(in_shape, FP4VecSize); - // printf("Expert %d: Weight offset: %lld, quant_weight_offset: %lld, scaling_factor_offset: %lld\n", - // (long long) i, (long long) i * in_shape * out_shape, - // (long long) i * in_shape * out_shape / WEIGHT_ELEM_PER_BYTE, - // (long long) i * (int64_t) padded_stride * cutlass::ceil_div(in_shape,
FP4VecSize)); - - // check_cuda_error(cudaStreamSynchronize(mStream->get())); - // std::cout << "Quant " << i << " starting" << std::endl; - // auto data = getDataFromDevice(scaling_factor_start, 4 * cutlass::ceil_div(in_shape, FP4VecSize)); - // for (auto v : data) - // { - // std::cout << (float) v << ", "; + MX_QUANT_WEIGHT, mMultiProcessorCount, mStream->get()); + + // auto sf_data = getDataFromDevice(scaling_factors, num_experts * padded_out_dim * padded_in_dim / + // FP4VecSize); auto unquant_data = getDataFromDevice(raw_weights, num_experts * out_shape * + // in_shape); auto quant_data = getDataFromDevice((uint32_t*)quant_weights, num_experts * out_shape * + // in_shape / 8); for(int expert = 0; expert < num_experts; expert++) { + // for(int i = 0; i < out_shape; i++) { + // for(int j = 0; j < in_shape / FP4VecSize; j++) { + // printf("quant_weights[(%d, %d, %d)]: ", expert, i, j * FP4VecSize); + // for(int k = 0; k < FP4VecSize / 8; k++) { + // printf("0x%08x, ", quant_data[(expert * out_shape * in_shape + i * in_shape + j * FP4VecSize) + // / 8 + k]); + // } + // printf("scaling_factors: %e, ", + // (float)sf_data[tensorrt_llm::kernels::get_sf_out_offset_128x4(expert, i, j, + // out_shape, in_shape)]); printf("original: "); for(int k = 0; k < FP4VecSize; k++) { + // printf("%e, ", (float)unquant_data[expert * out_shape * in_shape + i * in_shape + j * + // FP4VecSize + k]); + // } + // printf("\n"); + // } // } - // std::cout << std::endl; - - // invokeFP4Quantization(out_shape, in_shape, weight_start, global_scales + i, - // reinterpret_cast(quant_weight_start), reinterpret_cast(scaling_factor_start), - // MX_QUANT, tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED, mMultiProcessorCount, mStream->get()); - - // // check_cuda_error(cudaStreamSynchronize(mStream->get())); - // // std::cout << "Quant " << i << " done" << std::endl; - // // auto data = getDataFromDevice(mInputTensor, mHiddenSize); - // // for (auto v : data) - // // { - // // std::cout << (float)v << ", "; - // // } - // // std::cout << std::endl; // } } constexpr static float getFPXActScalar(float in) { - // Our FP8 x MXFP4 implementation uses a global scale factor. This should be skipped if we use MXFP8 x MXFP4 - if (FP8 || MIXED_FP4) + // Our FP8 x MXFP4 implementation uses a global scale factor + if (FP8 || FP8_MXFP4) return FP8_MAX / in; if (NVFP4) // We need to represent the block SF using FP8, so the largest value should be at most FP4_MAX * FP8_MAX // return FP8_MAX * FP4_MAX / in; // We carefully control precision in FP4. We want to avoid introducing any non-powers of two return 2.0f; + + // MX quant does not have a global scale factor return 1.0f; } @@ -770,8 +791,45 @@ protected: template auto populateTokens(std::vector& hidden_states) { - // Can't use FP8 param because we recurse with a different type, and we also reuse this for MIXED_FP4 - if constexpr (std::is_same_v) + if constexpr (MX_QUANT_ACT) // MXFP8_MXFP4 + { + int const max_order_of_magnitude = 4; + std::vector base(hidden_states.size()); + std::mt19937 gen(0xD5); + // Filthy hack so that GELU/SiLU does not introduce large quantization errors + float min = mIsGated ?
4.f : 0; + float max = FP8_MAX; + std::uniform_int_distribution is_negative(0, 10); + std::uniform_real_distribution dist(min, max); + std::generate(base.begin(), base.end(), + [&]() + { + if (is_negative(gen) == 0) + { + return float(__nv_fp8_e4m3(-dist(gen))); + } + else + { + return float(__nv_fp8_e4m3(dist(gen))); + } + }); + + // Avoid small values for gated activation + int adjustment = max_order_of_magnitude / 2; + for (int i = 0; i < hidden_states.size() / FP4VecSize; i++) + { + auto block_scale = mIsGated ? 1.f : exp2f(i % max_order_of_magnitude - adjustment); + hidden_states[i * FP4VecSize] = T(FP8_MAX * block_scale); + for (int j = 1; j < FP4VecSize; j++) + { + hidden_states[i * FP4VecSize + j] = T(base[i * FP4VecSize + j] * block_scale); + } + mMaxInput = std::max(mMaxInput, FP8_MAX * block_scale); + } + return hidden_states; + } + // Use the actual template param because we recurse with a different type + else if constexpr (std::is_same_v) // FP8, FP8_MXFP4 { // Call the standard setup and then perform the quantization manually std::vector internal_states(hidden_states.size()); @@ -786,7 +844,7 @@ protected: [scalar](T in) -> OutputType { return static_cast(((float) in) / scalar); }); return internal_states; } - else if constexpr (ACT_FP4) + else if constexpr (ACT_FP4) // NVFP4 { float const max_scale = 1.0f; mMaxInput = FP4_MAX * max_scale; @@ -796,7 +854,8 @@ protected: int stride = FP4VecSize; for (int i = 0; i < hidden_states.size(); i += stride) { - for (int j = 0; j < stride; j++) + hidden_states[i] = FP4_MAX * scale; + for (int j = 1; j < stride; j++) { hidden_states[i + j] = allowed_values[(i / stride + j) % allowed_values.size()] * scale; } @@ -809,7 +868,7 @@ protected: } return hidden_states; } - else + else // FP16, BF16, FP32, (recurse) FP8 { // Generates numbers in increments of 1/max_order_of_magnitude in the range [0, 1) constexpr int max_order_of_magnitude = 256; @@ -1078,13 +1137,21 @@ protected: ASSERT_TRUE(mExpertFP4ActGlobalScale1); ASSERT_TRUE(mFP4ScalingFactorsW1 && mFP4ScalingFactorsW2); ASSERT_TRUE(scale1_ptr && scale2_ptr && scale3_ptr); - auto fc1_sf_offset = mUsePerExpertActScale && NVFP4 - ? mNumExperts / parallelism_config.ep_size * parallelism_config.ep_rank - : 0; - auto constructor = NVFP4 ? &QuantParams::FP4 : &QuantParams::FP8MXFP4; - quant_params = constructor(mExpertFP4ActGlobalScale1 + fc1_sf_offset, mFP4ScalingFactorsW1, - static_cast(scale1_ptr), static_cast(scale2_ptr), mFP4ScalingFactorsW2, - static_cast(scale3_ptr), mUsePerExpertActScale && NVFP4, mUsePerExpertActScale); + if constexpr (NVFP4 || FP8_MXFP4) + { + auto fc1_sf_offset = mUsePerExpertActScale && NVFP4 + ? mNumExperts / parallelism_config.ep_size * parallelism_config.ep_rank + : 0; + auto constructor = NVFP4 ? 
&QuantParams::FP4 : &QuantParams::FP8MXFP4; + quant_params = constructor(mExpertFP4ActGlobalScale1 + fc1_sf_offset, mFP4ScalingFactorsW1, + static_cast(scale1_ptr), static_cast(scale2_ptr), mFP4ScalingFactorsW2, + static_cast(scale3_ptr), mUsePerExpertActScale && NVFP4, mUsePerExpertActScale); + } + else if constexpr (MXFP8_MXFP4) + { + quant_params = QuantParams::MXFP8MXFP4(mFP4ScalingFactorsW1, static_cast(scale1_ptr), + mFP4ScalingFactorsW2, static_cast(scale3_ptr)); + } } if constexpr (WEIGHT_FP4) @@ -1204,7 +1271,18 @@ protected: return in; } - float calcMLPVal(float input, int expert_id, bool final_bias = false) + float quantAct(float in, float block_max) + { + if (MX_QUANT_ACT) + { + float scale = std::exp2f(std::ceil(std::log2f(block_max / FP8_MAX))); + return float(__nv_fp8_e4m3(in / scale)) * scale; + } + // TODO Handle NVFP4 too so we can test non-relu actfns + return in; + } + + float calcMLPVal(float input, int expert_id, bool final_bias = false, float block_max = 1.f) { if (expert_id >= mNumExperts) return 0; @@ -1221,22 +1299,28 @@ protected: float gate = input * gated_scalar + gated_bias; activated = fc1 * actfn(gate); + + block_max = (block_max * scalar + w1_bias) * actfn(block_max * gated_scalar + gated_bias); } else { float scalar = applyExpertShift(mExpertWDiag1, expert_id); float fc1 = input * scalar + w1_bias; activated = actfn(fc1); + + block_max = actfn(block_max * scalar + w1_bias); } + activated = quantAct(activated, block_max); + EXPECT_TRUE(mUseBias || !final_bias); float result = activated * applyExpertShift(mExpertWDiag2, expert_id) + (float) (final_bias ? expert_id : 0); return result; } - float calcMLPValWithFinalBias(float input, int expert_id) + float calcMLPValWithFinalBias(float input, int expert_id, float block_max = 1.f) { - return calcMLPVal(input, expert_id, mUseBias); + return calcMLPVal(input, expert_id, mUseBias, block_max); } template @@ -1313,9 +1397,14 @@ protected: for (int64_t token_id = 0; token_id < mTotalTokens; token_id++) { + float block_max = 1.f; // NOTE: When mInterSize < mHiddenSize, those values get zeroed out by fc1 and lost for (int64_t hidden_id = 0; hidden_id < std::min(mHiddenSize, mInterSize); hidden_id++) { + if (MX_QUANT_ACT && hidden_id % FP4VecSize == 0) + { + block_max = input_data[token_id * mHiddenSize + hidden_id]; + } float sum = 0.0f; // Loop for the number of times each token is duplicated for (int k_idx = 0; k_idx < mK; k_idx++) @@ -1323,8 +1412,9 @@ protected: int selected_expert = expected_experts[token_id * mK + k_idx]; float final_scale_value = token_final_scales[token_id * mK + k_idx]; - float final_value = float(calcMLPValWithFinalBias( - static_cast(input_data[token_id * mHiddenSize + hidden_id]), selected_expert)); + float final_value = float( + calcMLPValWithFinalBias(static_cast(input_data[token_id * mHiddenSize + hidden_id]), + selected_expert, block_max)); sum += final_value * final_scale_value; } @@ -1354,13 +1444,12 @@ protected: } // Tensor parallel tests default to inter_size_fraction = 1.0f so that all ranks have interesting values (i.e. 
a - // diagonal non-square matrix would be all zeros for the last rank) Note when debugging we occasionally want to edit - // the HIDDEN_SIZE_MULTIPLIER to a smaller value to make inspecting weights easier, so account for this so the test - // doesn't fail + // diagonal non-square matrix would be all zeros for the last rank) void TensorParallelTest(int k = 1, int64_t hidden_size = DEFAULT_HIDDEN_SIZE, int64_t num_experts = 4, - int64_t num_tokens = 3, float inter_size_fraction = std::min(1.0f, HIDDEN_SIZE_MULTIPLIER / 8.0f)) + int64_t num_tokens = 3, float inter_size_fraction = 1.0f) { - mInterSizeFraction = inter_size_fraction; + // Ensure we dont drop below the minimum alignment + mInterSizeFraction = std::max(inter_size_fraction, MINIMUM_ALIGNMENT * 8.0f / hidden_size); ParallelismTest(k, 2, 1, hidden_size, num_experts, num_tokens); ParallelismTest(k, 4, 1, hidden_size, num_experts, num_tokens); ParallelismTest(k, 8, 1, hidden_size, num_experts, num_tokens); @@ -1369,7 +1458,7 @@ protected: void MixedParallelTest(int k = 1, int64_t hidden_size = DEFAULT_HIDDEN_SIZE, int64_t num_experts = 4, int64_t num_tokens = 3, float inter_size_fraction = 1.0f) { - mInterSizeFraction = inter_size_fraction; + mInterSizeFraction = std::max(inter_size_fraction, MINIMUM_ALIGNMENT * 8.0f / hidden_size); // 2 experts per rank ParallelismTest(k, 2, num_experts / 2, hidden_size, num_experts, num_tokens); @@ -1387,13 +1476,15 @@ protected: template using LargeMixtureOfExpertsTest = MixtureOfExpertsTest; -template +template struct WeightParams { using DataType = DataType_; using WeightType = WeightType_; using OutputType = OutputType_; - constexpr static bool UseMxQuant = MX_QUANT_; + using ActivationScale = ActivationScale_; + using WeightScale = WeightScale_; }; // TODO Fix int quantized @@ -1405,7 +1496,12 @@ using Types = ::testing::Types< WeightParams, #endif #ifdef ENABLE_FP4 - WeightParams, WeightParams, + WeightParams, + WeightParams, + +#ifdef USING_OSS_CUTLASS_MOE_GEMM + WeightParams, +#endif #endif WeightParams, WeightParams @@ -1651,6 +1747,12 @@ TYPED_TEST(MixtureOfExpertsTest, PermuteDeepSeekV3) this->BasicPermuteTest(8, hidden_size, 256, 100); } +TYPED_TEST(MixtureOfExpertsTest, MinimumAlignment) +{ + this->mInterSizeFraction = 1; + this->BasicPermuteTest(1, this->DEFAULT_HIDDEN_SIZE + this->MINIMUM_ALIGNMENT); +} + template std::vector MixtureOfExpertsTest::calcPermuteMapExpertParallel( std::vector const& expected_experts) @@ -1684,7 +1786,7 @@ void MixtureOfExpertsTest::ParallelismTest( { if (mActType != ActivationType::Relu) { - // FP4 has far too little precision to get any sort of consistency with non-relu actfn + // FP4 has too little precision to get any sort of consistency with non-relu actfn GTEST_SKIP(); return; } diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index a1b169689e4..810d3c33a69 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -42,6 +42,7 @@ def __init__( enable_alltoall: bool, use_deepseek_fp8_block_scale: bool, use_w4a8_group_scaling: bool, + use_mxfp8_act_scaling: bool, min_latency_mode: bool, ): self.x_dtype = x_dtype @@ -57,15 +58,18 @@ def __init__( self.enable_alltoall = enable_alltoall self.use_deepseek_fp8_block_scale = use_deepseek_fp8_block_scale self.use_w4a8_group_scaling = use_w4a8_group_scaling + self.use_mxfp8_act_scaling = use_mxfp8_act_scaling self.min_latency_mode = min_latency_mode instance_key = (x_dtype, 
weight_dtype, output_dtype, - use_deepseek_fp8_block_scale, use_w4a8_group_scaling) + use_deepseek_fp8_block_scale, use_w4a8_group_scaling, + use_mxfp8_act_scaling) if instance_key not in MoERunner.runner_dict: MoERunner.runner_dict[ instance_key] = torch.classes.trtllm.FusedMoeRunner( x_dtype, weight_dtype, output_dtype, - use_deepseek_fp8_block_scale, use_w4a8_group_scaling) + use_deepseek_fp8_block_scale, use_w4a8_group_scaling, + use_mxfp8_act_scaling) self.fused_moe_runner = MoERunner.runner_dict[instance_key] def get_valid_tactics( @@ -134,6 +138,7 @@ def fused_moe( enable_alltoall: bool = False, use_deepseek_fp8_block_scale: bool = False, use_w4a8_group_scaling: bool = False, + use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, tune_max_num_tokens: int = 8192, ) -> List[torch.Tensor]: @@ -156,6 +161,7 @@ def fused_moe( enable_alltoall=enable_alltoall, use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, use_w4a8_group_scaling=use_w4a8_group_scaling, + use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=min_latency_mode, ) @@ -227,6 +233,7 @@ def _( enable_alltoall: bool = False, use_deepseek_fp8_block_scale: bool = False, use_w4a8_group_scaling: bool = False, + use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, tune_max_num_tokens: int = 8192, ):
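For reference, a minimal usage sketch of the new flag (assumptions: a build that compiles these changes and the trtllm torch extension already loaded; the dtypes shown are illustrative):

import torch
# Hypothetical example: FP8 activations, MXFP4 weights packed into int64 (16 FP4 values per element),
# BF16 output. Passing use_mxfp8_act_scaling=True selects the MXFP8 x MXFP4 path (isWMxfp4AMxfp8Quant).
runner = torch.classes.trtllm.FusedMoeRunner(
    torch.float8_e4m3fn,  # activation dtype
    torch.int64,          # weight dtype (packed MXFP4)
    torch.bfloat16,       # output dtype
    False,                # use_deepseek_fp8_block_scale
    False,                # use_w4a8_group_scaling
    True)                 # use_mxfp8_act_scaling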