126 changes: 92 additions & 34 deletions csrc/libtorch_stable/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu
@@ -48,32 +48,52 @@ using namespace cute;
constexpr auto FLOAT4_E2M1X2 = torch::headeronly::ScalarType::Byte;
constexpr auto SF_DTYPE = torch::headeronly::ScalarType::Float8_e4m3fn;
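// Note: FLOAT4_E2M1X2 packs two 4-bit e2m1 values per byte, hence the
// Byte storage type; the per-block scale factors travel as FP8 e4m3.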

+// Configuration for swapAB (used when M <= 64 for better GPU utilization)
+// Uses a larger K tile (256) for better throughput at small batch sizes
+struct sm120_fp4_config_swapab {
+  using ClusterShape = Shape<_1, _1, _1>;
+  using MmaTileShape = Shape<_128, _128, _256>;
+  using PerSmTileShape_MNK = Shape<_128, _128, _256>;
+};
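// Why swapAB works (sketch): for D = A x B with row-major D, computing
// D^T = B^T x A^T is the same GEMM with the roles of M and N exchanged,
// and writing D^T in column-major order lands bytes exactly where
// row-major D expects them, so no explicit transpose is needed. For a
// decode-sized problem (M <= 64), the effective "M" becomes N, which is
// usually large enough to keep many more SMs busy with 128-wide tiles.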

// Configuration for M in (64, 256]
struct sm120_fp4_config_M256 {
using ClusterShape = Shape<_1, _1, _1>;
using MmaTileShape = Shape<_128, _128, _128>;
using PerSmTileShape_MNK = Shape<_128, _128, _128>;
};

// Configuration for M in (256, inf)
struct sm120_fp4_config_default {
using ClusterShape = Shape<_1, _1, _1>;
using MmaTileShape = Shape<_256, _128, _128>;
using PerSmTileShape_MNK = Shape<_256, _128, _128>;
};

-template <typename Config, typename OutType>
+template <typename Config, typename OutType, bool swap_ab_ = false>
struct Fp4GemmSm120 {
+static constexpr bool swap_ab = swap_ab_;
+
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutATag = cutlass::layout::RowMajor;
+using LayoutATag_Transpose =
+typename cutlass::layout::LayoutTranspose<LayoutATag>::type;
static constexpr int AlignmentA = 32;

using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
using LayoutBTag = cutlass::layout::ColumnMajor;
+using LayoutBTag_Transpose =
+typename cutlass::layout::LayoutTranspose<LayoutBTag>::type;
static constexpr int AlignmentB = 32;

using ElementD = OutType;
using ElementC = OutType;
using LayoutCTag = cutlass::layout::RowMajor;
+using LayoutCTag_Transpose =
+typename cutlass::layout::LayoutTranspose<LayoutCTag>::type;
using LayoutDTag = cutlass::layout::RowMajor;
+using LayoutDTag_Transpose =
+typename cutlass::layout::LayoutTranspose<LayoutDTag>::type;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
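// Alignments are counted in elements: C/D use 128 bits / sizeof_bits of
// the element (8 for bf16/fp16), while AlignmentA/B = 32 fp4 elements,
// which is likewise 32 x 4 = 128 bits.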

@@ -85,37 +105,50 @@ struct Fp4GemmSm120 {
using ClusterShape = typename Config::ClusterShape;
using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;

+// Conditionally use transposed C/D layouts when swap_ab is enabled
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
-ElementAccumulator, ElementC, LayoutCTag, AlignmentC, ElementD,
-LayoutDTag, AlignmentD,
+ElementAccumulator, ElementC,
+conditional_t<swap_ab, LayoutCTag_Transpose, LayoutCTag>, AlignmentC,
+ElementD, conditional_t<swap_ab, LayoutDTag_Transpose, LayoutDTag>,
+AlignmentD,
cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

-using CollectiveMainloop =
+// Conditionally swap A/B operands and their layouts in the mainloop
+using CollectiveMainloop = conditional_t<
+swap_ab,
+typename cutlass::gemm::collective::CollectiveBuilder<
+ArchTag, OperatorClass, ElementB, LayoutBTag_Transpose, AlignmentB,
+ElementA, LayoutATag_Transpose, AlignmentA, ElementAccumulator,
+MmaTileShape, ClusterShape,
+cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
+sizeof(typename CollectiveEpilogue::SharedStorage))>,
+cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutATag, AlignmentA, ElementB,
LayoutBTag, AlignmentB, ElementAccumulator, MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
-cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
+cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp>;
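// StageCountAutoCarveout picks the mainloop pipeline stage count after
// reserving shared memory for the epilogue's SharedStorage, so both
// stages of the kernel fit the SM's shared-memory budget.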

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
};

-template <typename Gemm>
-typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
-torch::stable::Tensor const& A,
-torch::stable::Tensor const& B,
-torch::stable::Tensor const& A_sf,
-torch::stable::Tensor const& B_sf,
-torch::stable::Tensor const& alpha,
-int M, int N, int K) {
+template <typename GemmConfig>
+typename GemmConfig::Gemm::Arguments args_from_options(
+torch::stable::Tensor& D, torch::stable::Tensor const& A,
+torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
+torch::stable::Tensor const& B_sf, torch::stable::Tensor const& alpha,
+int M, int N, int K) {
+static constexpr bool swap_ab = GemmConfig::swap_ab;
+using Gemm = typename GemmConfig::Gemm;

using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
using ElementD = typename Gemm::ElementD;
@@ -131,22 +164,35 @@ typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
using Sm1xxBlkScaledConfig =
typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
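// Sm1xxBlkScaledConfig describes the interleaved block-scale-factor
// layout the tensor cores expect; for nvfp4, one e4m3 scale typically
// covers a 16-element block of fp4 values.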

-auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1});
-auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1});
-auto stride_D = cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1});
-
-auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(
-cute::make_shape(M, N, K, 1));
-auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(
-cute::make_shape(M, N, K, 1));
+// When swap_ab, the GEMM problem becomes (N, M, K) instead of (M, N, K)
+int m_eff = swap_ab ? N : M;
+int n_eff = swap_ab ? M : N;
+
+auto stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m_eff, K, 1});
+auto stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n_eff, K, 1});
+auto stride_D =
+cutlass::make_cute_packed_stride(StrideD{}, {m_eff, n_eff, 1});
+
+auto prob_shape = cute::make_shape(m_eff, n_eff, K, 1);
+auto layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(prob_shape);
+auto layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(prob_shape);
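// e.g. decode with M = 8, N = 4096: the kernel then sees a 4096 x 8
// problem, so 128-row CTA tiles span N instead of one mostly-empty
// 8-row tile covering the whole GPU.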

+// When swap_ab: kernel's A operand gets B data, B operand gets A data,
+// and scale factor pointers are swapped accordingly
+auto* a_data = swap_ab ? static_cast<ElementA const*>(B.data_ptr())
+: static_cast<ElementA const*>(A.data_ptr());
+auto* b_data = swap_ab ? static_cast<ElementB const*>(A.data_ptr())
+: static_cast<ElementB const*>(B.data_ptr());
+auto* sfa_data = swap_ab ? static_cast<ElementSFA const*>(B_sf.data_ptr())
+: static_cast<ElementSFA const*>(A_sf.data_ptr());
+auto* sfb_data = swap_ab ? static_cast<ElementSFB const*>(A_sf.data_ptr())
+: static_cast<ElementSFB const*>(B_sf.data_ptr());

typename Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGemm,
-{M, N, K, 1},
-{static_cast<ElementA const*>(A.data_ptr()), stride_A,
-static_cast<ElementB const*>(B.data_ptr()), stride_B,
-static_cast<ElementSFA const*>(A_sf.data_ptr()), layout_SFA,
-static_cast<ElementSFB const*>(B_sf.data_ptr()), layout_SFB},
+{m_eff, n_eff, K, 1},
+{a_data, stride_A, b_data, stride_B, sfa_data, layout_SFA, sfb_data,
+layout_SFB},
{{},
static_cast<ElementD const*>(D.data_ptr()),
stride_D,
@@ -158,15 +204,17 @@ typename Gemm::Arguments args_from_options(torch::stable::Tensor& D,
return arguments;
}

-template <typename Gemm>
+template <typename GemmConfig>
void runGemm(torch::stable::Tensor& D, torch::stable::Tensor const& A,
torch::stable::Tensor const& B, torch::stable::Tensor const& A_sf,
torch::stable::Tensor const& B_sf,
torch::stable::Tensor const& alpha, int M, int N, int K,
cudaStream_t stream) {
+using Gemm = typename GemmConfig::Gemm;
Gemm gemm;

-auto arguments = args_from_options<Gemm>(D, A, B, A_sf, B_sf, alpha, M, N, K);
+auto arguments =
+args_from_options<GemmConfig>(D, A, B, A_sf, B_sf, alpha, M, N, K);

size_t workspace_size = Gemm::get_workspace_size(arguments);
auto workspace =
@@ -188,11 +236,16 @@ void cutlass_fp4_bf16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
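// Assuming next_pow_2 rounds up, mp2 buckets m into power-of-two sizes
// (min 16) so nearby batch sizes share a tile config; e.g. m = 40 ->
// mp2 = 64 -> swapAB path, m = 65 -> mp2 = 128 -> M256 path.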
-if (mp2 <= 256) {
-runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>::Gemm>(
+if (mp2 <= 64) {
+// SwapAB for small M to improve GPU utilization during decode
+runGemm<Fp4GemmSm120<sm120_fp4_config_swapab, cutlass::bfloat16_t,
+/*swap_ab=*/true>>(D, A, B, A_sf, B_sf, alpha, m, n, k,
+stream);
+} else if (mp2 <= 256) {
+runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::bfloat16_t>>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
-runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>::Gemm>(
+runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::bfloat16_t>>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
}
@@ -205,11 +258,16 @@ void cutlass_fp4_f16_gemm_dispatch(torch::stable::Tensor& D,
torch::stable::Tensor const& alpha, int m,
int n, int k, cudaStream_t stream) {
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
-if (mp2 <= 256) {
-runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>::Gemm>(
+if (mp2 <= 64) {
+// SwapAB for small M to improve GPU utilization during decode
+runGemm<Fp4GemmSm120<sm120_fp4_config_swapab, cutlass::half_t,
+/*swap_ab=*/true>>(D, A, B, A_sf, B_sf, alpha, m, n, k,
+stream);
+} else if (mp2 <= 256) {
+runGemm<Fp4GemmSm120<sm120_fp4_config_M256, cutlass::half_t>>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
} else {
-runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>::Gemm>(
+runGemm<Fp4GemmSm120<sm120_fp4_config_default, cutlass::half_t>>(
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
}