@@ -93,8 +93,28 @@ struct TmaWarpSpecializedGroupedGemmInput
     using NVFP4BlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<NVFP4BlockScaleVectorSize>;
     using MXFPXBlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig<MXFPXBlockScaleVectorSize>;

-    constexpr static int MinNumRowsAlignmentNVFP4 = cute::size<0>(NVFP4BlockScaledConfig::SfAtom{});
-    constexpr static int MinNumRowsAlignmentMXFPX = cute::size<0>(MXFPXBlockScaledConfig::SfAtom{});
+    // 128
+    // This is the alignment of the weight matrix that the fully padded SF will refer to.
+    // We require the SFs to be aligned to this value (zero padded as needed).
+    // The weights do not need to be aligned to this value; CUTLASS will handle the extra padding.
+    // N here is shorthand for the outer dimension of the GEMM; this applies to both the M and N dimensions.
+    constexpr static int MinNDimAlignmentNVFP4 = cute::size<0>(NVFP4BlockScaledConfig::SfAtom{});
+    constexpr static int MinNDimAlignmentMXFPX = cute::size<0>(MXFPXBlockScaledConfig::SfAtom{});
+
+    // Block scale vector size * 4
+    // This is the alignment of the weight matrix that the fully padded SF will refer to.
+    // We should never actually need to pad a buffer to this alignment:
+    // the weights only need to be aligned to BlockScaleVectorSize (CUTLASS will handle the extra padding),
+    // and the SFs only need to be aligned to 4 (zero padded as needed).
+    // K here is shorthand for the inner dimension of the GEMM.
+    constexpr static int MinKDimAlignmentNVFP4 = cute::size<1>(NVFP4BlockScaledConfig::SfAtom{});
+    constexpr static int MinKDimAlignmentMXFPX = cute::size<1>(MXFPXBlockScaledConfig::SfAtom{});
+
+    // Helper function to align a dimension up to the given SF alignment
+    constexpr static int64_t alignToSfDim(int64_t dim, int64_t alignment)
+    {
+        return (dim + alignment - 1) / alignment * alignment;
+    }

     using StrideA
         = std::remove_pointer_t<cutlass::detail::TagToStrideB_t<LayoutA*>>; // Use B because they will be swapped
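For reference, the new alignToSfDim helper is a standard round-up-to-multiple. A minimal self-contained sketch, restating the helper as a free function so the checks compile on their own; the 128 and 4 values follow the N-dim and SF K-dim alignments described in the comments above:

// Illustrative only: rounds a GEMM dimension up to the SF alignment described above.
// alignToSfDim(dim, align) == ceil(dim / align) * align for positive alignments.
#include <cstdint>

constexpr int64_t alignToSfDim(int64_t dim, int64_t alignment)
{
    return (dim + alignment - 1) / alignment * alignment;
}

static_assert(alignToSfDim(1, 128) == 128);   // N dim padded up to the 128-row SfAtom tile
static_assert(alignToSfDim(128, 128) == 128); // already aligned: unchanged
static_assert(alignToSfDim(130, 4) == 132);   // SFs along K only need alignment to 4, per the comments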
26 changes: 26 additions & 0 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h
@@ -210,6 +210,21 @@ struct QuantParams
         GemmInputs fc2;
     } fp8_mxfp4;

+    // MXFP8 MXFP4 quantization params
+    // This mode uses block-scaled MXFP8 activations and block-scaled MXFP4 weights
+    struct MXFP8MXFP4Inputs
+    {
+        struct GemmInputs
+        {
+            TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* weight_block_scale
+                = nullptr;                       // (experts, n, k / 32)
+            float const* global_scale = nullptr; // (num_experts_per_node, )
+        };
+
+        GemmInputs fc1;
+        GemmInputs fc2;
+    } mxfp8_mxfp4;
+
     // FP4 quantization params
     struct FP4Inputs
     {
@@ -291,13 +306,24 @@ struct QuantParams
         return qp;
     }

+    static QuantParams MXFP8MXFP4(TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc1_weight_block_scale,
+        float const* fc1_global_scale, //
+        TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF const* fc2_weight_block_scale, float const* fc2_global_scale)
+    {
+        QuantParams qp;
+        qp.mxfp8_mxfp4.fc1 = {fc1_weight_block_scale, fc1_global_scale};
+        qp.mxfp8_mxfp4.fc2 = {fc2_weight_block_scale, fc2_global_scale};
+        return qp;
+    }
+
     static QuantParams FP4(float const* fc1_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc1_weight_block_scale,
         float const* fc1_global_scale, //
         float const* fc2_act_global_scale,
         TmaWarpSpecializedGroupedGemmInput::NVFP4ElementSF const* fc2_weight_block_scale,
         float const* fc2_global_scale, //
         bool fc1_use_per_expert_act_scale = false, bool fc2_use_per_expert_act_scale = false)
+
     {
         QuantParams qp;
         qp.fp4.fc1 = {fc1_use_per_expert_act_scale, fc1_act_global_scale, fc1_weight_block_scale, fc1_global_scale};
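A hedged sketch of how a caller might build QuantParams for the new path, mirroring the factory added above and assuming the headers from this PR are included; the d_-prefixed pointer names are hypothetical device buffers, not names from this PR:

// Illustrative only: wiring hypothetical device buffers into the new factory.
using SF = TmaWarpSpecializedGroupedGemmInput::MXFPXElementSF;
SF const* d_fc1_weight_sf = nullptr;      // (experts, n, k / 32) block scales
float const* d_fc1_global_scale = nullptr; // (num_experts_per_node, )
SF const* d_fc2_weight_sf = nullptr;
float const* d_fc2_global_scale = nullptr;

QuantParams qp = QuantParams::MXFP8MXFP4(
    d_fc1_weight_sf, d_fc1_global_scale, d_fc2_weight_sf, d_fc2_global_scale);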
@@ -52,14 +52,13 @@ void threeStepBuildExpertMapsSortFirstToken(int const* token_selected_experts, i
     int64_t const num_tokens, int64_t const num_experts_per_node, int64_t const num_experts_per_token,
     int const start_expert_id, cudaStream_t stream);

-template <class InputActivationsType, class ExpandedActivationsType, bool PRE_QUANT_AWQ = false>
+template <class InputActivationsType, class ExpandedActivationsType>
 void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     ExpandedActivationsType* permuted_output, float const* unpermuted_scales, float* permuted_scales,
-    int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const cols, int const k,
-    int const num_experts_per_node, float const* fc1_act_global_scale, bool use_per_expert_act_scale,
+    int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k,
+    int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale,
     int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat,
-    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, cudaStream_t stream,
-    void const* prequant_scales = nullptr);
+    TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream);

 template <class OutputType, class GemmOutputType, class ScaleBiasType>
 void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows,
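The net effect on callers of expandInputRowsKernelLauncher: the PRE_QUANT_AWQ template parameter is gone, cols is renamed hidden_size, the fc1 activation scale now arrives inside quant_params, and prequant_scales becomes a required argument placed before the stream. A hedged sketch of the updated call shape; argument names simply echo the new parameter list, and InputT/ExpandedT are placeholder template arguments:

// Illustrative call only: values are placeholders echoing the new parameter names.
expandInputRowsKernelLauncher<InputT, ExpandedT>(unpermuted_input, permuted_output,
    unpermuted_scales, permuted_scales, permuted_row_to_unpermuted_row, num_rows, hidden_size, k,
    num_experts_per_node, quant_params, use_per_expert_act_scale, expert_first_token_offset,
    fc1_act_sf_flat, input_sf, /*prequant_scales=*/nullptr, stream);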