diff --git a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu
index b500ae5a0a74..66090ea5c5fb 100644
--- a/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu
+++ b/csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu
@@ -48,32 +48,52 @@ using namespace cute;
 constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
 constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
 
+// Configuration for swapAB (used when M <= 64 for better GPU utilization)
+// Uses larger K tile (256) for better throughput on small batch sizes
+struct sm120_fp4_config_swapab {
+  using ClusterShape = Shape<_1, _1, _1>;
+  using MmaTileShape = Shape<_128, _128, _256>;
+  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
+};
+
+// Configuration for M in (64, 256]
 struct sm120_fp4_config_M256 {
   using ClusterShape = Shape<_1, _1, _1>;
   using MmaTileShape = Shape<_128, _128, _128>;
   using PerSmTileShape_MNK = Shape<_128, _128, _128>;
 };
 
+// Configuration for M in (256, inf)
 struct sm120_fp4_config_default {
   using ClusterShape = Shape<_1, _1, _1>;
   using MmaTileShape = Shape<_256, _128, _128>;
   using PerSmTileShape_MNK = Shape<_256, _128, _128>;
 };
 
-template <typename Config, typename OutType>
+template <typename Config, typename OutType, bool swap_ab_ = false>
 struct Fp4GemmSm120 {
+  static constexpr bool swap_ab = swap_ab_;
+
   using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
   using LayoutATag = cutlass::layout::RowMajor;
+  using LayoutATag_Transpose =
+      typename cutlass::layout::LayoutTranspose<LayoutATag>::type;
   static constexpr int AlignmentA = 32;
 
   using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
   using LayoutBTag = cutlass::layout::ColumnMajor;
+  using LayoutBTag_Transpose =
+      typename cutlass::layout::LayoutTranspose<LayoutBTag>::type;
   static constexpr int AlignmentB = 32;
 
   using ElementD = OutType;
   using ElementC = OutType;
   using LayoutCTag = cutlass::layout::RowMajor;
+  using LayoutCTag_Transpose =
+      typename cutlass::layout::LayoutTranspose<LayoutCTag>::type;
   using LayoutDTag = cutlass::layout::RowMajor;
+  using LayoutDTag_Transpose =
+      typename cutlass::layout::LayoutTranspose<LayoutDTag>::type;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
   static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
@@ -85,22 +105,34 @@ struct Fp4GemmSm120 {
   using ClusterShape = typename Config::ClusterShape;
   using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
 
+  // Conditionally use transposed C/D layouts when swap_ab is enabled
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
          ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
-          ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
-          LayoutDTag, AlignmentD,
+          ElementAccumulator, ElementC,
+          conditional_t<swap_ab, LayoutCTag_Transpose, LayoutCTag>, AlignmentC,
+          ElementD, conditional_t<swap_ab, LayoutDTag_Transpose, LayoutDTag>,
+          AlignmentD,
          cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;
 
-  using CollectiveMainloop =
+  // Conditionally swap A/B operands and their layouts in the mainloop
+  using CollectiveMainloop = conditional_t<
+      swap_ab,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag, OperatorClass, ElementB, LayoutBTag_Transpose, AlignmentB,
+          ElementA, LayoutATag_Transpose, AlignmentA, ElementAccumulator,
+          MmaTileShape, ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+              sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp,
       typename cutlass::gemm::collective::CollectiveBuilder<
          ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB,
          LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
          ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
-          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+          cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp>;
 
   using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
       Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
@@ -108,14 +140,15 @@ struct Fp4GemmSm120 {
   using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
 };
 
-template <typename Gemm>
-typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
-                                           torch::stable::Tensor const& A,
-                                           torch::stable::Tensor const& B,
-                                           torch::stable::Tensor const& A_sf,
-                                           torch::stable::Tensor const& B_sf,
-                                           torch::stable::Tensor const& alpha,
-                                           int M, int N, int K) {
+template <typename GemmConfig>
+typename GemmConfig::Gemm::Arguments args_from_options(
+    torch::stable::Tensor& D, torch::stable::Tensor const& A,
+    torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
+    torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
+    int M, int N, int K) {
+  static constexpr bool swap_ab = GemmConfig::swap_ab;
+  using Gemm = typename GemmConfig::Gemm;
+
   using ElementA = typename Gemm::ElementA;
   using ElementB = typename Gemm::ElementB;
   using ElementD = typename Gemm::ElementD;
@@ -131,22 +164,35 @@ typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
   using Sm1xxBlkScaledConfig =
       typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
 
-  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
-  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
-  auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
-
-  auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
-      cute::make_shape(M, N, K, 1));
-  auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
-      cute::make_shape(M, N, K, 1));
+  // When swap_ab, the GEMM problem becomes (N, M, K) instead of (M, N, K)
+  int m_eff = swap_ab ? N : M;
+  int n_eff = swap_ab ? M : N;
+
+  auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m_eff, K, 1});
+  auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n_eff, K, 1});
+  auto stride_D =
+      cutlass::make_cute_packed_stride(StrideD{}, {m_eff, n_eff, 1});
+
+  auto prob_shape = cute::make_shape(m_eff, n_eff, K, 1);
+  auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(prob_shape);
+  auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(prob_shape);
+
+  // When swap_ab: kernel's A operand gets B data, B operand gets A data,
+  // and scale factor pointers are swapped accordingly
+  auto* a_data = swap_ab ? static_cast<ElementA const*>(B.data_ptr())
+                         : static_cast<ElementA const*>(A.data_ptr());
+  auto* b_data = swap_ab ? static_cast<ElementB const*>(A.data_ptr())
+                         : static_cast<ElementB const*>(B.data_ptr());
+  auto* sfa_data = swap_ab ? static_cast<ElementSFA const*>(B_sf.data_ptr())
+                           : static_cast<ElementSFA const*>(A_sf.data_ptr());
+  auto* sfb_data = swap_ab ? static_cast<ElementSFB const*>(A_sf.data_ptr())
+                           : static_cast<ElementSFB const*>(B_sf.data_ptr());
 
   typename Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGemm,
-      {M, N, K, 1},
-      {static_cast<ElementA const*>(A.data_ptr()), stride_A,
-       static_cast<ElementB const*>(B.data_ptr()), stride_B,
-       static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA,
-       static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB},
+      {m_eff, n_eff, K, 1},
+      {a_data, stride_A, b_data, stride_B, sfa_data, layout_SFA, sfb_data,
+       layout_SFB},
      {{},
       static_cast<ElementD*>(D.data_ptr()),
       stride_D,
@@ -158,15 +204,17 @@ typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
   return arguments;
 }
 
-template <typename Gemm>
+template <typename GemmConfig>
 void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
             torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
             torch::stable::Tensor const& B_sf,
             torch::stable::Tensor const& alpha, int M, int N, int K,
             cudaStream_t stream) {
+  using Gemm = typename GemmConfig::Gemm;
   Gemm gemm;
 
-  auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
+  auto arguments =
+      args_from_options<GemmConfig>(D, A, B, A_sf, B_sf, alpha, M, N, K);
 
   size_t workspace_size = Gemm::get_workspace_size(arguments);
   auto workspace =
@@ -188,11 +236,16 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
                                     torch::stable::Tensor const& alpha, int m,
                                     int n, int k, cudaStream_t stream) {
   uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
-  if (mp2 <= 256) {
-    runGemm<typename Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
+  if (mp2 <= 64) {
+    // SwapAB for small M to improve GPU utilization during decode
+    runGemm<Fp4GemmSm120<sm120_fp4_config_swapab, cutlass::bfloat16_t, true>>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (mp2 <= 256) {
+    runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>>(
        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
   } else {
-    runGemm<typename Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>::Gemm>(
+    runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>>(
        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
   }
 }
@@ -205,11 +258,16 @@ void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
                                    torch::stable::Tensor const& alpha, int m,
                                    int n, int k, cudaStream_t stream) {
   uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
-  if (mp2 <= 256) {
-    runGemm<typename Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
+  if (mp2 <= 64) {
+    // SwapAB for small M to improve GPU utilization during decode
+    runGemm<Fp4GemmSm120<sm120_fp4_config_swapab, cutlass::half_t, true>>(
+        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
+  } else if (mp2 <= 256) {
+    runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>>(
        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  } else {
-    runGemm<typename Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>::Gemm>(
+    runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>>(
        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
  }
 }
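
For reference, below is a minimal standalone sketch (not part of the patch) of the shape-based dispatch logic this change introduces. When M is at most 64, the GEMM is run with A and B swapped, i.e. the problem is posed as D^T = B^T * A^T so the large N dimension fills the 128-wide M tile, and the transposed C/D layouts in the epilogue keep the output row-major in the original orientation. The names pick_config and Config below are hypothetical helpers for illustration only; the thresholds (64, 256) and tile shapes mirror the dispatch functions in the diff.

// Illustrative sketch only, assuming a GCC/Clang toolchain (__builtin_clz).
#include <algorithm>
#include <cstdint>
#include <cstdio>

// Round m up to the next power of two (mirrors next_pow_2 used in the file).
static uint32_t next_pow_2(uint32_t x) {
  if (x <= 1) return 1;
  return 1u << (32 - __builtin_clz(x - 1));
}

// Hypothetical stand-ins for sm120_fp4_config_swapab (128x128x256, swapAB),
// sm120_fp4_config_M256 (128x128x128) and sm120_fp4_config_default
// (256x128x128).
enum class Config { SwapAB, M256, Default };

// Same selection rule as cutlass_fp4_bf16_gemm_dispatch /
// cutlass_fp4_f16_gemm_dispatch: clamp to at least 16, round up to a power of
// two, then pick the tile configuration by the effective M.
static Config pick_config(int m) {
  uint32_t const mp2 = std::max(16u, next_pow_2(static_cast<uint32_t>(m)));
  if (mp2 <= 64) return Config::SwapAB;   // decode-sized batches: swap A/B
  if (mp2 <= 256) return Config::M256;    // mid-sized M
  return Config::Default;                 // large M
}

int main() {
  for (int m : {1, 16, 64, 65, 256, 257, 4096}) {
    std::printf("M=%4d -> config %d\n", m, static_cast<int>(pick_config(m)));
  }
  return 0;
}

The point of the swapAB branch is purely occupancy: with a 128-row M tile, an M of 16 or 32 would leave most of each tile idle, whereas after the swap the kernel tiles over N (typically thousands for weight matrices) and only the epilogue has to account for the transposed output.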