From 06747c3fc9804440117317713332b5ad00ce57ae Mon Sep 17 00:00:00 2001
From: Vincent Huang
Date: Thu, 14 Aug 2025 22:36:29 +0000
Subject: [PATCH] Perf: no longer need to combine scale_a/scale_b in cutlass bmm_fp8

The previous CUTLASS implementation required combining scale_a and scale_b in
Python, so an extra ATen kernel was launched, which may cost about 1 us. The
epilogue now supports separate scale_a and scale_b, so this extra ATen kernel
is avoided.

Signed-off-by: Vincent Huang
---
 csrc/fp8_gemm_cutlass.cu | 28 ++++---
 flashinfer/gemm.py | 3 +-
 include/flashinfer/gemm/fp8_gemm_cutlass.h | 13 +--
 .../gemm/fp8_gemm_cutlass_template.h | 79 ++++++++++---------
 .../flashinfer/gemm/fp8_gemm_template_sm100.h | 43 +++++++---
 5 files changed, 99 insertions(+), 67 deletions(-)

diff --git a/csrc/fp8_gemm_cutlass.cu b/csrc/fp8_gemm_cutlass.cu
index 696d1cdd26..a26d859ee9 100644
--- a/csrc/fp8_gemm_cutlass.cu
+++ b/csrc/fp8_gemm_cutlass.cu
@@ -59,8 +59,8 @@ CutlassGemmConfig getFp8GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tact
 template <typename T>
 void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
-             at::Tensor const& scale, int64_t m, int64_t n, int64_t k, int64_t b,
-             CutlassGemmConfig const& gemmConfig, at::Tensor workspace_buffer) {
+             at::Tensor const& scale_a, at::Tensor const& scale_b, int64_t m, int64_t n, int64_t k,
+             int64_t b, CutlassGemmConfig const& gemmConfig, at::Tensor workspace_buffer) {
   CutlassFp8GemmRunner<T> gemmRunner;
   int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k);
@@ -70,8 +70,9 @@ void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
   auto runKernel = [&](void* workspace) {
     gemmRunner.gemm(reinterpret_cast<__nv_fp8_e4m3 const*>(mat1.const_data_ptr()),
                     reinterpret_cast<__nv_fp8_e4m3 const*>(mat2.const_data_ptr()),
-                    reinterpret_cast<float const*>(scale.const_data_ptr()), out.data_ptr(), m, n, k,
-                    b, gemmConfig, reinterpret_cast<char*>(workspace), required_workspace_size,
+                    reinterpret_cast<float const*>(scale_a.const_data_ptr()),
+                    reinterpret_cast<float const*>(scale_b.const_data_ptr()), out.data_ptr(), m, n,
+                    k, b, gemmConfig, reinterpret_cast<char*>(workspace), required_workspace_size,
                     at::cuda::getCurrentCUDAStream(mat1.get_device()));
   };
@@ -85,10 +86,13 @@ void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
   }
 }
-at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale,
-                        at::Tensor out, at::Tensor workspace_buffer, int64_t tactic) {
+at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale_a,
+                        at::Tensor const& scale_b, at::Tensor out, at::Tensor workspace_buffer,
+                        int64_t tactic) {
   CHECK_INPUT(mat1);
   CHECK_INPUT(mat2);
+  CHECK_INPUT(scale_a);
+  CHECK_INPUT(scale_b);
   int mat2_k_scale = 1;
@@ -135,10 +139,11 @@ at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tens
   switch (out.scalar_type()) {
     case at::ScalarType::Half:
-      runGemm<half>(out, mat1, mat2, scale, m, n, k, b, config, workspace_buffer);
+      runGemm<half>(out, mat1, mat2, scale_a, scale_b, m, n, k, b, config, workspace_buffer);
       break;
     case at::ScalarType::BFloat16:
-      runGemm<__nv_bfloat16>(out, mat1, mat2, scale, m, n, k, b, config, workspace_buffer);
+      runGemm<__nv_bfloat16>(out, mat1, mat2, scale_a, scale_b, m, n, k, b, config,
+                             workspace_buffer);
       break;
     default:
       TORCH_CHECK(false, "out_dtype must be one of fp16/bf16.");
   }
@@ -148,9 +153,10 @@ at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tens
 }  // namespace
-at::Tensor fp8_gemm(at::Tensor const& mat1, at::Tensor const& mat2,
at::Tensor const& scale, - at::Tensor out, at::Tensor workspace_buffer, int64_t tactic) { - return fp8_bmm_impl(mat1, mat2, scale, out, workspace_buffer, tactic); +at::Tensor fp8_gemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale_a, + at::Tensor const& scale_b, at::Tensor out, at::Tensor workspace_buffer, + int64_t tactic) { + return fp8_bmm_impl(mat1, mat2, scale_a, scale_b, out, workspace_buffer, tactic); } int64_t fp8_gemm_tactic_num() { diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 1726a34195..b59a5af894 100755 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -406,7 +406,8 @@ def forward( module.fp8_gemm.default( a, b.transpose(-2, -1), - scale_a * scale_b, + scale_a, + scale_b, out, workspace_buffer, tactic, diff --git a/include/flashinfer/gemm/fp8_gemm_cutlass.h b/include/flashinfer/gemm/fp8_gemm_cutlass.h index 3168b66b2a..8dadac716b 100644 --- a/include/flashinfer/gemm/fp8_gemm_cutlass.h +++ b/include/flashinfer/gemm/fp8_gemm_cutlass.h @@ -31,9 +31,10 @@ class CutlassFp8GemmRunnerInterface { CutlassFp8GemmRunnerInterface() = default; virtual ~CutlassFp8GemmRunnerInterface() = default; - virtual void gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* alpha, void* D, - int m, int n, int k, int b, CutlassGemmConfig gemmConfig, char* workspacePtr, - size_t const workspaceBytes, cudaStream_t stream) = 0; + virtual void gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* scale_a, + float const* scale_b, void* D, int m, int n, int k, int b, + CutlassGemmConfig gemmConfig, char* workspacePtr, size_t const workspaceBytes, + cudaStream_t stream) = 0; virtual size_t getWorkspaceSize(int m, int n, int k) = 0; @@ -46,9 +47,9 @@ class CutlassFp8GemmRunner : public virtual CutlassFp8GemmRunnerInterface { CutlassFp8GemmRunner() = default; ~CutlassFp8GemmRunner() = default; - void gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* alpha, void* D, int m, - int n, int k, int b, CutlassGemmConfig gemmConfig, char* workspacePtr, - size_t const workspaceBytes, cudaStream_t stream) override; + void gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* scale_a, + float const* scale_b, void* D, int m, int n, int k, int b, CutlassGemmConfig gemmConfig, + char* workspacePtr, size_t const workspaceBytes, cudaStream_t stream) override; size_t getWorkspaceSize(int m, int n, int k) override; std::vector getConfigs() const override; diff --git a/include/flashinfer/gemm/fp8_gemm_cutlass_template.h b/include/flashinfer/gemm/fp8_gemm_cutlass_template.h index 90d0395834..a3c9d0f2c4 100644 --- a/include/flashinfer/gemm/fp8_gemm_cutlass_template.h +++ b/include/flashinfer/gemm/fp8_gemm_cutlass_template.h @@ -50,43 +50,44 @@ struct _2SM {}; template size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, - float const* alpha, T* D, int m, int n, int k, int b, - CutlassGemmConfig config, char* workspacePtr, - size_t const workspaceBytes, cudaStream_t stream); + float const* scale_a, float const* scale_b, T* D, int m, + int n, int k, int b, CutlassGemmConfig config, + char* workspacePtr, size_t const workspaceBytes, + cudaStream_t stream); template size_t dispatchGemmClusterShapeSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, - float const* alpha, T* D, int m, int n, int k, int b, - CutlassGemmConfig gemmConfig, char* workspacePtr, + float const* scale_a, float const* scale_b, T* D, int m, int n, + int k, int b, CutlassGemmConfig gemmConfig, char* workspacePtr, size_t const 
workspaceBytes, cudaStream_t stream) { using namespace cute; switch (gemmConfig.cluster_shape) { case ClusterShape::ClusterShape_1x1x1: return genericFp8GemmKernelLauncherSm100, - _1SM>(A, B, alpha, D, m, n, k, b, gemmConfig, - workspacePtr, workspaceBytes, stream); + _1SM>( + A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); break; case ClusterShape::ClusterShape_2x1x1: return genericFp8GemmKernelLauncherSm100, - _2SM>(A, B, alpha, D, m, n, k, b, gemmConfig, - workspacePtr, workspaceBytes, stream); + _2SM>( + A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); break; case ClusterShape::ClusterShape_1x2x1: return genericFp8GemmKernelLauncherSm100, - _1SM>(A, B, alpha, D, m, n, k, b, gemmConfig, - workspacePtr, workspaceBytes, stream); + _1SM>( + A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); break; case ClusterShape::ClusterShape_2x2x1: return genericFp8GemmKernelLauncherSm100, - _2SM>(A, B, alpha, D, m, n, k, b, gemmConfig, - workspacePtr, workspaceBytes, stream); + _2SM>( + A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); break; case ClusterShape::ClusterShape_1x4x1: return genericFp8GemmKernelLauncherSm100, - _1SM>(A, B, alpha, D, m, n, k, b, gemmConfig, - workspacePtr, workspaceBytes, stream); + _1SM>( + A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream); break; default: throw std::runtime_error("invalid config for fp8 gemm"); @@ -95,9 +96,10 @@ size_t dispatchGemmClusterShapeSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const } template -size_t dispatchToArch(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* alpha, void* D, - int m, int n, int k, int b, CutlassGemmConfig gemmConfig, char* workspacePtr, - size_t const workspaceBytes, cudaStream_t stream) { +size_t dispatchToArch(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* scale_a, + float const* scale_b, void* D, int m, int n, int k, int b, + CutlassGemmConfig gemmConfig, char* workspacePtr, size_t const workspaceBytes, + cudaStream_t stream) { using namespace cute; using arch = cutlass::arch::Sm100; @@ -107,34 +109,34 @@ size_t dispatchToArch(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float cons // A rowmajor, B colmajor , C, D rowmajor switch (gemmConfig.tile_config_sm100) { case CutlassTileConfigSM100::CtaShape64x64x128B: - return dispatchGemmClusterShapeSm100(B, A, alpha, static_cast(D), n, - m, k, b, gemmConfig, workspacePtr, - workspaceBytes, stream); + return dispatchGemmClusterShapeSm100( + B, A, scale_a, scale_b, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); break; case CutlassTileConfigSM100::CtaShape64x128x128B: return dispatchGemmClusterShapeSm100( - B, A, alpha, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, - stream); + B, A, scale_a, scale_b, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); break; case CutlassTileConfigSM100::CtaShape64x256x128B: return dispatchGemmClusterShapeSm100( - B, A, alpha, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, - stream); + B, A, scale_a, scale_b, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); break; case CutlassTileConfigSM100::CtaShape128x64x128B: return dispatchGemmClusterShapeSm100( - B, A, alpha, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, - stream); + B, A, scale_a, scale_b, 
static_cast(D), n, m, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); break; case CutlassTileConfigSM100::CtaShape128x128x128B: return dispatchGemmClusterShapeSm100( - B, A, alpha, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, - stream); + B, A, scale_a, scale_b, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); break; case CutlassTileConfigSM100::CtaShape128x256x128B: return dispatchGemmClusterShapeSm100( - B, A, alpha, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes, - stream); + B, A, scale_a, scale_b, static_cast(D), n, m, k, b, gemmConfig, workspacePtr, + workspaceBytes, stream); break; default: @@ -145,11 +147,12 @@ size_t dispatchToArch(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float cons template void CutlassFp8GemmRunner::gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, - float const* alpha, void* D, int m, int n, int k, int b, - CutlassGemmConfig gemmConfig, char* workspacePtr, - size_t const workspaceBytes, cudaStream_t stream) { - dispatchToArch(A, B, alpha, reinterpret_cast(D), m, n, k, b, gemmConfig, workspacePtr, - workspaceBytes, stream); + float const* scale_a, float const* scale_b, void* D, int m, + int n, int k, int b, CutlassGemmConfig gemmConfig, + char* workspacePtr, size_t const workspaceBytes, + cudaStream_t stream) { + dispatchToArch(A, B, scale_a, scale_b, reinterpret_cast(D), m, n, k, b, gemmConfig, + workspacePtr, workspaceBytes, stream); } template @@ -158,8 +161,8 @@ size_t CutlassFp8GemmRunner::getWorkspaceSizeImpl(int m, int n, int k) { auto gemmConfigs = CutlassFp8GemmRunner{}.getConfigs(); for (auto const& gemmConfig : gemmConfigs) { try { - size_t curr_workspace_size = dispatchToArch(nullptr, nullptr, nullptr, nullptr, m, n, k, 1, - gemmConfig, nullptr, 0, nullptr); + size_t curr_workspace_size = dispatchToArch(nullptr, nullptr, nullptr, nullptr, nullptr, m, + n, k, 1, gemmConfig, nullptr, 0, nullptr); workspace_size = std::max(workspace_size, curr_workspace_size); } catch (std::runtime_error& e) { diff --git a/include/flashinfer/gemm/fp8_gemm_template_sm100.h b/include/flashinfer/gemm/fp8_gemm_template_sm100.h index fd35975cdd..e4db773ec2 100644 --- a/include/flashinfer/gemm/fp8_gemm_template_sm100.h +++ b/include/flashinfer/gemm/fp8_gemm_template_sm100.h @@ -23,9 +23,12 @@ #include "cutlass/arch/arch.h" #include "cutlass/cutlass.h" #include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/numeric_conversion.h" #include "flashinfer/arch_condition.h" #include "flashinfer/cutlass_utils.cuh" @@ -68,9 +71,10 @@ struct SMTypeAdapter<_2SM> { template size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, - float const* alpha, T* D, int m, int n, int k, int b, - CutlassGemmConfig config, char* workspacePtr, - size_t const workspaceBytes, cudaStream_t stream) { + float const* scale_a, float const* scale_b, T* D, int m, + int n, int k, int b, CutlassGemmConfig config, + char* workspacePtr, size_t const workspaceBytes, + cudaStream_t stream) { using namespace cute; // A matrix configuration @@ -124,10 +128,23 @@ size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 c using MainloopSchedule = 
typename SMTypeAdapter::MainloopSchedule; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + using CustomEVT = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>, // scale_a * scale_b * acc + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // scale_a + cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>, // scale_b * acc + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // scale_b + cutlass::epilogue::fusion::Sm90AccFetch // acc + >>; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - EpilogueSchedule>::CollectiveOp; + EpilogueSchedule, CustomEVT>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, @@ -168,11 +185,15 @@ size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 c reinterpret_cast(D), stride_D}}; - auto& fusion_args = arguments.epilogue.thread; - fusion_args.alpha = 0.F; - fusion_args.beta = 0.F; - fusion_args.alpha_ptr = alpha; - fusion_args.beta_ptr = nullptr; + arguments.epilogue.thread = { + {{0.F}, {scale_a}}, // scale_a + { + {{0.F}, {scale_b}}, // scale_b + {}, // acc + {} // multiplies + }, + {} // multiplies + }; Gemm gemm; @@ -210,8 +231,8 @@ size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 c template size_t genericFp8GemmKernelLauncherSm100< \ RET_TYPE, cutlass::arch::Sm100, TILE_M, TILE_N, TILE_K, \ cute::Shape, cute::Int, cute::Int>, SM_TYPE>( \ - __nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* alpha, RET_TYPE* D, int m, \ - int n, int k, int b, CutlassGemmConfig config, char* workspacePtr, \ + __nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* scale_a, float const* scale_b, \ + RET_TYPE* D, int m, int n, int k, int b, CutlassGemmConfig config, char* workspacePtr, \ size_t const workspaceBytes, cudaStream_t stream); #endif // FLASHINFER_FP8_GEMM_TEMPLATE_SM100_H_
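
Illustrative note (not part of the patch): the fused epilogue computes D = scale_a * (scale_b * acc) per element, which matches the value the old path produced by pre-multiplying alpha = scale_a * scale_b in Python; only the extra ATen multiply kernel goes away. Below is a minimal PyTorch sketch of that equivalence; old_path/new_path are hypothetical helper names, not FlashInfer APIs.

    import torch

    def old_path(acc, scale_a, scale_b):
        alpha = scale_a * scale_b  # extra ATen kernel launched from Python
        return alpha * acc

    def new_path(acc, scale_a, scale_b):
        # what the fused CUTLASS epilogue now evaluates:
        # Sm90EVT(multiplies, scale_a, Sm90EVT(multiplies, scale_b, acc))
        return scale_a * (scale_b * acc)

    acc = torch.randn(8, 8, dtype=torch.float32)  # stand-in for the fp32 accumulator
    scale_a = torch.tensor(0.125)                 # per-tensor scale for A
    scale_b = torch.tensor(2.0)                   # per-tensor scale for B
    assert torch.allclose(old_path(acc, scale_a, scale_b), new_path(acc, scale_a, scale_b))

The same nesting appears in the CustomEVT above: the inner Sm90EVT multiplies scale_b into the accumulator fetch (Sm90AccFetch), and the outer Sm90EVT multiplies scale_a into that intermediate result.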