From e1a1278d516d5ba60be080f22b402d04bc6e9c68 Mon Sep 17 00:00:00 2001
From: Michel Belleau
Date: Tue, 13 Jan 2026 08:06:57 -0500
Subject: [PATCH] [Fix][MoE] Add SM120 support for FP8 MoE path

Add SM120 (Blackwell) support to cutlass_moe_mm to enable FP8 MoE models
(GLM-4.7, MiniMax M2.1) on RTX PRO 6000 Blackwell GPUs.

Changes:
- Add cutlass_moe_mm_sm120 function declaration
- Add SM120 conditional branch (version_num >= 120 && version_num < 130)
- Create csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm120.cu
  implementation
- Update error message to include SM120 in supported capabilities

Note: This is currently untested. It will be battle-tested on RTX PRO 6000
Blackwell hardware with GLM-4.7-FP8; performance results will be added to
the PR description after testing.

Fixes: #32109
---
 .../w8a8/cutlass/moe/grouped_mm_c3x_sm120.cu  | 139 ++++++++++++++++++
 .../w8a8/cutlass/scaled_mm_entry.cu           |  20 ++-
 2 files changed, 158 insertions(+), 1 deletion(-)
 create mode 100644 csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm120.cu

diff --git a/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm120.cu b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm120.cu
new file mode 100644
index 000000000000..92f1372cf2d2
--- /dev/null
+++ b/csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm120.cu
@@ -0,0 +1,139 @@
+#include <cudaTypedefs.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <torch/all.h>
+
+#include "cutlass/cutlass.h"
+#include "grouped_mm_c3x.cuh"
+
+using namespace cute;
+
+namespace {
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm120_fp8_config_default {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm120;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+  // Untested starting point; tile/cluster shapes still need tuning on SM120.
+  using TileShape = cute::Shape<_128, _128, _128>;
+  using ClusterShape = cute::Shape<_1, _1, _1>;
+  using ArchTag = cutlass::arch::Sm120;
+
+  using Cutlass3xGemm =
+      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm120_fp8_config_M64 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm120;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
+  // Smaller M tile for low-token batches (untested starting point).
+  using TileShape = cute::Shape<_64, _128, _128>;
+  using ClusterShape = cute::Shape<_1, _1, _1>;
+  using ArchTag = cutlass::arch::Sm120;
+
+  using Cutlass3xGemm =
+      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType,
+          template <typename, typename, typename> typename Epilogue>
+struct sm120_fp8_config_N8192 {
+  static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
+  using KernelSchedule =
+      cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm120;
+  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
+  // Wide-N tile with a 2-CTA cluster (untested starting point).
+  using TileShape = cute::Shape<_256, _128, _128>;
+  using ClusterShape = cute::Shape<_2, _1, _1>;
+  using ArchTag = cutlass::arch::Sm120;
+
+  using Cutlass3xGemm =
+      cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
+                            KernelSchedule, EpilogueSchedule>;
+};
+
+template <typename InType, typename OutType>
+void run_cutlass_moe_mm_sm120(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
+  TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
+  TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
+  TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
+
+  TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
+              "A tensors must be of type float8_e4m3fn.");
+  TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
+              "B tensors must be of type float8_e4m3fn.");
+
+  using Cutlass3xGemmDefault = typename sm120_fp8_config_default<
+      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
+  using Cutlass3xGemmN8192 = typename sm120_fp8_config_N8192<
+      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
+  using Cutlass3xGemmM64 = typename sm120_fp8_config_M64<
+      InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
+
+  uint32_t const m = a_tensors.size(0);
+  uint32_t const n = out_tensors.size(1);
+
+  if (m <= 64) {
+    cutlass_group_gemm_caller<Cutlass3xGemmM64>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  } else if (n >= 8192) {
+    cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  } else {
+    cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  }
+}
+
+void dispatch_moe_mm_sm120(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
+  if (out_tensors.dtype() == torch::kBFloat16) {
+    run_cutlass_moe_mm_sm120<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  } else {
+    run_cutlass_moe_mm_sm120<cutlass::float_e4m3_t, cutlass::half_t>(
+        out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
+        problem_sizes, a_strides, b_strides, c_strides, per_act_token,
+        per_out_ch);
+  }
+}
+
+}  // namespace
+
+void cutlass_moe_mm_sm120(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch) {
+  dispatch_moe_mm_sm120(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                        expert_offsets, problem_sizes, a_strides, b_strides,
+                        c_strides, per_act_token, per_out_ch);
+}
\ No newline at end of file
diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
index 077966a1d92a..a9f9e2fe24b2 100644
--- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
+++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
@@ -51,6 +51,16 @@ void cutlass_moe_mm_sm100(
     bool per_act_token, bool per_out_ch);
 #endif
 
+#if defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120
+void cutlass_moe_mm_sm120(
+    torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
+    torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
+    torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
+    torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch);
+#endif
+
 #if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
 void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                              torch::Tensor const& b,
@@ -259,6 +269,14 @@ void cutlass_moe_mm(
     torch::Tensor const& b_strides, torch::Tensor const& c_strides,
     bool per_act_token, bool per_out_ch) {
   int32_t version_num = get_sm_version_num();
+#if defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120
+  if (version_num >= 120 && version_num < 130) {
+    cutlass_moe_mm_sm120(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
+                         expert_offsets, problem_sizes, a_strides, b_strides,
+                         c_strides, per_act_token, per_out_ch);
+    return;
+  }
+#endif
 #if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
   if (version_num >= 100 && version_num < 110) {
     cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
@@ -278,7 +296,7 @@ void cutlass_moe_mm(
   TORCH_CHECK_NOT_IMPLEMENTED(
       false,
       "No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
-      ". Required capability: 90 or 100");
+      ". Required capability: 90, 100, or 120");
 }
 
 void get_cutlass_moe_mm_data(
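
Reviewer note: below is a minimal standalone sketch of the new capability gate,
for anyone checking which devices land on the SM120 path. It assumes
get_sm_version_num() returns major * 10 + minor, matching the existing helper
in scaled_mm_entry.cu; takes_sm120_path is a hypothetical name used here for
illustration only, not part of the patch. The half-open [120, 130) window also
covers any future 12.x minor revisions.

// Sketch only: mirrors the `version_num >= 120 && version_num < 130` branch.
// Assumes version_num = major * 10 + minor, as in get_sm_version_num().
#include <cstdio>

static bool takes_sm120_path(int major, int minor) {
  int version_num = major * 10 + minor;  // e.g. SM 12.0 -> 120
  return version_num >= 120 && version_num < 130;
}

int main() {
  // Hopper (9.0) and data-center Blackwell (10.0) keep their existing paths;
  // 12.x parts such as the RTX PRO 6000 Blackwell take the new SM120 branch.
  const int caps[][2] = {{9, 0}, {10, 0}, {12, 0}, {12, 1}};
  for (const auto& cap : caps) {
    std::printf("SM %d.%d -> sm120 MoE branch: %s\n", cap[0], cap[1],
                takes_sm120_path(cap[0], cap[1]) ? "yes" : "no");
  }
  return 0;
}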