From bd5e7869a811d786e99b0df3b9f39bce0d2ff16c Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 08:43:54 +0000 Subject: [PATCH 1/2] =?UTF-8?q?Add=20sm120=5Ffp4=5Fconfig=5Fsmall=5Fm=20(1?= =?UTF-8?q?28x128x256)=20for=20M=E2=89=A432=20and=203-tier=20dispatch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reference: sglang PR #21314 - New tile config sm120_fp4_config_small_m with MmaTileShape 128x128x256 for small M values (M ≤ 32), doubling K tile for better throughput - Updated dispatch: M≤32 → small_m, M≤256 → M256, M>256 → default - ~20% speedup for decode-phase small-batch GEMM operations Co-authored-by: GitHub Copilot Agent-Logs-Url: https://github.com/Nekofish-L/vllm/sessions/66285e45-f69c-404b-975a-4afc5d3edb4e Co-authored-by: Nekofish-L <29830327+Nekofish-L@users.noreply.github.com> --- .../fp4/nvfp4_scaled_mm_sm120_kernels.cu | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu index b500ae5a0a74..fd8c08f7822d 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu @@ -48,12 +48,21 @@ using namespace cute; constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte; constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn; +// Configuration for M in [1, 32] - uses larger K tile for better throughput +struct sm120_fp4_config_small_m { + using ClusterShape = Shape<_1, _1, _1>; + using MmaTileShape = Shape<_128, _128, _256>; + using PerSmTileShape_MNK = Shape<_128, _128, _256>; +}; + +// Configuration for M in (32, 256] struct sm120_fp4_config_M256 { using ClusterShape = Shape<_1, _1, _1>; using MmaTileShape = Shape<_128, _128, _128>; using 
PerSmTileShape_MNK = Shape<_128, _128, _128>; }; +// Configuration for M in (256, inf) struct sm120_fp4_config_default { using ClusterShape = Shape<_1, _1, _1>; using MmaTileShape = Shape<_256, _128, _128>; @@ -188,7 +197,10 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D, torch::stable::Tensor const& alpha, int m, int n, int k, cudaStream_t stream) { uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); - if (mp2 <= 256) { + if (mp2 <= 32) { + runGemm<Fp4GemmSm120<cutlass::bfloat16_t, sm120_fp4_config_small_m>::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (mp2 <= 256) { runGemm<Fp4GemmSm120<cutlass::bfloat16_t, sm120_fp4_config_M256>::Gemm>( D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { @@ -205,7 +217,10 @@ void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D, torch::stable::Tensor const& alpha, int m, int n, int k, cudaStream_t stream) { uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); - if (mp2 <= 256) { + if (mp2 <= 32) { + runGemm<Fp4GemmSm120<cutlass::half_t, sm120_fp4_config_small_m>::Gemm>( + D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + } else if (mp2 <= 256) { runGemm<Fp4GemmSm120<cutlass::half_t, sm120_fp4_config_M256>::Gemm>( D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { From 45df94974c04021880138e7a0736686ec1c89b7d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Wed, 15 Apr 2026 09:21:59 +0000 Subject: [PATCH 2/2] feat: add swapAB optimization for SM120 NVFP4 GEMM kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When M is small (≤64), swap A/B operands so the small M dimension becomes the N dimension in the CUTLASS GEMM. This improves GPU utilization during decode by providing better CTA scheduling and memory access patterns. Follows the same pattern used in FP8 SM90, SM100, and SM120 blockwise kernels.
Co-authored-by: GitHub Copilot Agent-Logs-Url: https://github.com/Nekofish-L/vllm/sessions/86332631-5db7-485e-8d7f-3f51fce66977 Co-authored-by: Nekofish-L <29830327+Nekofish-L@users.noreply.github.com> --- .../fp4/nvfp4_scaled_mm_sm120_kernels.cu | 125 ++++++++++++------ 1 file changed, 84 insertions(+), 41 deletions(-) diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu index fd8c08f7822d..66090ea5c5fb 100644 --- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu +++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu @@ -48,14 +48,15 @@ using namespace cute; constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte; constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn; -// Configuration for M in [1, 32] - uses larger K tile for better throughput -struct sm120_fp4_config_small_m { +// Configuration for swapAB (used when M <= 64 for better GPU utilization) +// Uses larger K tile (256) for better throughput on small batch sizes +struct sm120_fp4_config_swapab { using ClusterShape = Shape<_1, _1, _1>; using MmaTileShape = Shape<_128, _128, _256>; using PerSmTileShape_MNK = Shape<_128, _128, _256>; }; -// Configuration for M in (32, 256] +// Configuration for M in (64, 256] struct sm120_fp4_config_M256 { using ClusterShape = Shape<_1, _1, _1>; using MmaTileShape = Shape<_128, _128, _128>; @@ -69,20 +70,30 @@ struct sm120_fp4_config_default { using PerSmTileShape_MNK = Shape<_256, _128, _128>; }; -template <typename OutType, typename Config> +template <typename OutType, typename Config, bool swap_ab_ = false> struct Fp4GemmSm120 { + static constexpr bool swap_ab = swap_ab_; + using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>; using LayoutATag = cutlass::layout::RowMajor; + using LayoutATag_Transpose = + typename cutlass::layout::LayoutTranspose<LayoutATag>::type; static constexpr int AlignmentA = 32; using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>; using LayoutBTag = cutlass::layout::ColumnMajor; + using
LayoutBTag_Transpose = + typename cutlass::layout::LayoutTranspose<LayoutBTag>::type; static constexpr int AlignmentB = 32; using ElementD = OutType; using ElementC = OutType; using LayoutCTag = cutlass::layout::RowMajor; + using LayoutCTag_Transpose = + typename cutlass::layout::LayoutTranspose<LayoutCTag>::type; using LayoutDTag = cutlass::layout::RowMajor; + using LayoutDTag_Transpose = + typename cutlass::layout::LayoutTranspose<LayoutDTag>::type; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value; static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; @@ -94,22 +105,34 @@ struct Fp4GemmSm120 { using ClusterShape = typename Config::ClusterShape; using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK; + // Conditionally use transposed C/D layouts when swap_ab is enabled using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, - ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD, - LayoutDTag, AlignmentD, + ElementAccumulator, ElementC, + conditional_t<swap_ab, LayoutCTag_Transpose, LayoutCTag>, AlignmentC, + ElementD, conditional_t<swap_ab, LayoutDTag_Transpose, LayoutDTag>, + AlignmentD, cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp; - using CollectiveMainloop = + // Conditionally swap A/B operands and their layouts in the mainloop + using CollectiveMainloop = conditional_t< + swap_ab, + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementB, LayoutBTag_Transpose, AlignmentB, + ElementA, LayoutATag_Transpose, AlignmentA, ElementAccumulator, + MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp, typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB, LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
ClusterShape, cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>( sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp; + cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp>; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>; @@ -117,14 +140,15 @@ struct Fp4GemmSm120 { using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>; }; -template <typename Gemm> -typename Gemm::Arguments args_from_options(torch::stable::Tensor& D, - torch::stable::Tensor const& A, - torch::stable::Tensor const& B, - torch::stable::Tensor const& A_sf, - torch::stable::Tensor const& B_sf, - torch::stable::Tensor const& alpha, - int M, int N, int K) { +template <typename GemmConfig> +typename GemmConfig::Gemm::Arguments args_from_options( + torch::stable::Tensor& D, torch::stable::Tensor const& A, + torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf, + torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha, + int M, int N, int K) { + static constexpr bool swap_ab = GemmConfig::swap_ab; + using Gemm = typename GemmConfig::Gemm; + using ElementA = typename Gemm::ElementA; using ElementB = typename Gemm::ElementB; using ElementD = typename Gemm::ElementD; @@ -140,22 +164,35 @@ typename Gemm::Arguments args_from_options(torch::stable::Tensor& D, using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; - auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1}); - auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1}); - auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1}); - - auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( cute::make_shape(M, N, K, 1)); - auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( cute::make_shape(M, N, K, 1)); + // When swap_ab, the GEMM problem becomes (N, M, K) instead of (M, N, K) + int m_eff = swap_ab ?
N : M; + int n_eff = swap_ab ? M : N; + + auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m_eff, K, 1}); + auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n_eff, K, 1}); + auto stride_D = + cutlass::make_cute_packed_stride(StrideD{}, {m_eff, n_eff, 1}); + + auto prob_shape = cute::make_shape(m_eff, n_eff, K, 1); + auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(prob_shape); + auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(prob_shape); + + // When swap_ab: kernel's A operand gets B data, B operand gets A data, + // and scale factor pointers are swapped accordingly + auto* a_data = swap_ab ? static_cast<ElementA const*>(B.data_ptr()) + : static_cast<ElementA const*>(A.data_ptr()); + auto* b_data = swap_ab ? static_cast<ElementB const*>(A.data_ptr()) + : static_cast<ElementB const*>(B.data_ptr()); + auto* sfa_data = swap_ab ? static_cast<ElementSFA const*>(B_sf.data_ptr()) + : static_cast<ElementSFA const*>(A_sf.data_ptr()); + auto* sfb_data = swap_ab ? static_cast<ElementSFB const*>(A_sf.data_ptr()) + : static_cast<ElementSFB const*>(B_sf.data_ptr()); typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kGemm, - {M, N, K, 1}, - {static_cast<ElementA const*>(A.data_ptr()), stride_A, - static_cast<ElementB const*>(B.data_ptr()), stride_B, - static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA, - static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB}, + {m_eff, n_eff, K, 1}, + {a_data, stride_A, b_data, stride_B, sfa_data, layout_SFA, sfb_data, + layout_SFB}, {{}, static_cast<ElementC const*>(D.data_ptr()), stride_D, @@ -167,15 +204,17 @@ typename Gemm::Arguments args_from_options(torch::stable::Tensor& D, return arguments; } -template <typename Gemm> +template <typename GemmConfig> void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A, torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf, torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha, int M, int N, int K, cudaStream_t stream) { + using Gemm = typename GemmConfig::Gemm; Gemm gemm; - auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K); + auto arguments = + args_from_options<GemmConfig>(D, A, B, A_sf, B_sf, alpha, M, N, K); size_t
workspace_size = Gemm::get_workspace_size(arguments); auto workspace = @@ -197,14 +236,16 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D, torch::stable::Tensor const& alpha, int m, int n, int k, cudaStream_t stream) { uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); - if (mp2 <= 32) { - runGemm<Fp4GemmSm120<cutlass::bfloat16_t, sm120_fp4_config_small_m>::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + if (mp2 <= 64) { + // SwapAB for small M to improve GPU utilization during decode + runGemm<Fp4GemmSm120<cutlass::bfloat16_t, sm120_fp4_config_swapab, true>>(D, A, B, A_sf, B_sf, alpha, m, n, k, + stream); } else if (mp2 <= 256) { - runGemm<Fp4GemmSm120<cutlass::bfloat16_t, sm120_fp4_config_M256>::Gemm>( + runGemm<Fp4GemmSm120<cutlass::bfloat16_t, sm120_fp4_config_M256>>( D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { - runGemm<Fp4GemmSm120<cutlass::bfloat16_t, sm120_fp4_config_default>::Gemm>( + runGemm<Fp4GemmSm120<cutlass::bfloat16_t, sm120_fp4_config_default>>( D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } } @@ -217,14 +258,16 @@ void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D, torch::stable::Tensor const& alpha, int m, int n, int k, cudaStream_t stream) { uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m)); - if (mp2 <= 32) { - runGemm<Fp4GemmSm120<cutlass::half_t, sm120_fp4_config_small_m>::Gemm>( - D, A, B, A_sf, B_sf, alpha, m, n, k, stream); + if (mp2 <= 64) { + // SwapAB for small M to improve GPU utilization during decode + runGemm<Fp4GemmSm120<cutlass::half_t, sm120_fp4_config_swapab, true>>(D, A, B, A_sf, B_sf, alpha, m, n, k, + stream); } else if (mp2 <= 256) { - runGemm<Fp4GemmSm120<cutlass::half_t, sm120_fp4_config_M256>::Gemm>( + runGemm<Fp4GemmSm120<cutlass::half_t, sm120_fp4_config_M256>>( D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } else { - runGemm<Fp4GemmSm120<cutlass::half_t, sm120_fp4_config_default>::Gemm>( + runGemm<Fp4GemmSm120<cutlass::half_t, sm120_fp4_config_default>>( D, A, B, A_sf, B_sf, alpha, m, n, k, stream); } }