Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions csrc/fp8_gemm_cutlass.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ CutlassGemmConfig getFp8GemmConfig(int64_t m, int64_t n, int64_t k, int64_t tact

template <typename T>
void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
at::Tensor const& scale, int64_t m, int64_t n, int64_t k, int64_t b,
CutlassGemmConfig const& gemmConfig, at::Tensor workspace_buffer) {
at::Tensor const& scale_a, at::Tensor const& scale_b, int64_t m, int64_t n, int64_t k,
int64_t b, CutlassGemmConfig const& gemmConfig, at::Tensor workspace_buffer) {
CutlassFp8GemmRunner<T> gemmRunner;

int64_t const required_workspace_size = gemmRunner.getWorkspaceSize(m, n, k);
Expand All @@ -70,8 +70,9 @@ void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
auto runKernel = [&](void* workspace) {
gemmRunner.gemm(reinterpret_cast<__nv_fp8_e4m3 const*>(mat1.const_data_ptr()),
reinterpret_cast<__nv_fp8_e4m3 const*>(mat2.const_data_ptr()),
reinterpret_cast<float const*>(scale.const_data_ptr()), out.data_ptr(), m, n, k,
b, gemmConfig, reinterpret_cast<char*>(workspace), required_workspace_size,
reinterpret_cast<float const*>(scale_a.const_data_ptr()),
reinterpret_cast<float const*>(scale_b.const_data_ptr()), out.data_ptr(), m, n,
k, b, gemmConfig, reinterpret_cast<char*>(workspace), required_workspace_size,
at::cuda::getCurrentCUDAStream(mat1.get_device()));
};

Expand All @@ -85,10 +86,13 @@ void runGemm(at::Tensor& out, at::Tensor const& mat1, at::Tensor const& mat2,
}
}

at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale,
at::Tensor out, at::Tensor workspace_buffer, int64_t tactic) {
at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale_a,
at::Tensor const& scale_b, at::Tensor out, at::Tensor workspace_buffer,
int64_t tactic) {
CHECK_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scale_a);
CHECK_INPUT(scale_b);

int mat2_k_scale = 1;

Expand Down Expand Up @@ -135,10 +139,11 @@ at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tens

switch (out.scalar_type()) {
case at::ScalarType::Half:
runGemm<half>(out, mat1, mat2, scale, m, n, k, b, config, workspace_buffer);
runGemm<half>(out, mat1, mat2, scale_a, scale_b, m, n, k, b, config, workspace_buffer);
break;
case at::ScalarType::BFloat16:
runGemm<__nv_bfloat16>(out, mat1, mat2, scale, m, n, k, b, config, workspace_buffer);
runGemm<__nv_bfloat16>(out, mat1, mat2, scale_a, scale_b, m, n, k, b, config,
workspace_buffer);
break;
default:
TORCH_CHECK(false, "out_dtype must be one of fp16/bf16.");
Expand All @@ -148,9 +153,10 @@ at::Tensor fp8_bmm_impl(at::Tensor const& mat1, at::Tensor const& mat2, at::Tens

} // namespace

at::Tensor fp8_gemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale,
at::Tensor out, at::Tensor workspace_buffer, int64_t tactic) {
return fp8_bmm_impl(mat1, mat2, scale, out, workspace_buffer, tactic);
at::Tensor fp8_gemm(at::Tensor const& mat1, at::Tensor const& mat2, at::Tensor const& scale_a,
at::Tensor const& scale_b, at::Tensor out, at::Tensor workspace_buffer,
int64_t tactic) {
return fp8_bmm_impl(mat1, mat2, scale_a, scale_b, out, workspace_buffer, tactic);
}

int64_t fp8_gemm_tactic_num() {
Expand Down
3 changes: 2 additions & 1 deletion flashinfer/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ def forward(
module.fp8_gemm.default(
a,
b.transpose(-2, -1),
scale_a * scale_b,
scale_a,
scale_b,
out,
workspace_buffer,
tactic,
Expand Down
13 changes: 7 additions & 6 deletions include/flashinfer/gemm/fp8_gemm_cutlass.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ class CutlassFp8GemmRunnerInterface {
CutlassFp8GemmRunnerInterface() = default;
virtual ~CutlassFp8GemmRunnerInterface() = default;

virtual void gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* alpha, void* D,
int m, int n, int k, int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
size_t const workspaceBytes, cudaStream_t stream) = 0;
virtual void gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* scale_a,
float const* scale_b, void* D, int m, int n, int k, int b,
CutlassGemmConfig gemmConfig, char* workspacePtr, size_t const workspaceBytes,
cudaStream_t stream) = 0;

virtual size_t getWorkspaceSize(int m, int n, int k) = 0;

Expand All @@ -46,9 +47,9 @@ class CutlassFp8GemmRunner : public virtual CutlassFp8GemmRunnerInterface {
CutlassFp8GemmRunner() = default;
~CutlassFp8GemmRunner() = default;

void gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* alpha, void* D, int m,
int n, int k, int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
size_t const workspaceBytes, cudaStream_t stream) override;
void gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* scale_a,
float const* scale_b, void* D, int m, int n, int k, int b, CutlassGemmConfig gemmConfig,
char* workspacePtr, size_t const workspaceBytes, cudaStream_t stream) override;
size_t getWorkspaceSize(int m, int n, int k) override;
std::vector<CutlassGemmConfig> getConfigs() const override;

Expand Down
79 changes: 41 additions & 38 deletions include/flashinfer/gemm/fp8_gemm_cutlass_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,43 +50,44 @@ struct _2SM {};
template <typename T, typename arch, int32_t CTA_M_, int32_t CTA_N_, int32_t CTA_K_,
typename ClusterShape_, typename XSM_>
size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B,
float const* alpha, T* D, int m, int n, int k, int b,
CutlassGemmConfig config, char* workspacePtr,
size_t const workspaceBytes, cudaStream_t stream);
float const* scale_a, float const* scale_b, T* D, int m,
int n, int k, int b, CutlassGemmConfig config,
char* workspacePtr, size_t const workspaceBytes,
cudaStream_t stream);

template <typename T, typename arch, int32_t CTA_M_, int32_t CTA_N_, int32_t CTA_K_>
size_t dispatchGemmClusterShapeSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B,
float const* alpha, T* D, int m, int n, int k, int b,
CutlassGemmConfig gemmConfig, char* workspacePtr,
float const* scale_a, float const* scale_b, T* D, int m, int n,
int k, int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
size_t const workspaceBytes, cudaStream_t stream) {
using namespace cute;

switch (gemmConfig.cluster_shape) {
case ClusterShape::ClusterShape_1x1x1:
return genericFp8GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _1, _1>,
_1SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
workspacePtr, workspaceBytes, stream);
_1SM>(
A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
break;

case ClusterShape::ClusterShape_2x1x1:
return genericFp8GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_2, _1, _1>,
_2SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
workspacePtr, workspaceBytes, stream);
_2SM>(
A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
break;
case ClusterShape::ClusterShape_1x2x1:
return genericFp8GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _2, _1>,
_1SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
workspacePtr, workspaceBytes, stream);
_1SM>(
A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
break;
case ClusterShape::ClusterShape_2x2x1:
return genericFp8GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_2, _2, _1>,
_2SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
workspacePtr, workspaceBytes, stream);
_2SM>(
A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
break;
case ClusterShape::ClusterShape_1x4x1:
return genericFp8GemmKernelLauncherSm100<T, arch, CTA_M_, CTA_N_, CTA_K_, Shape<_1, _4, _1>,
_1SM>(A, B, alpha, D, m, n, k, b, gemmConfig,
workspacePtr, workspaceBytes, stream);
_1SM>(
A, B, scale_a, scale_b, D, m, n, k, b, gemmConfig, workspacePtr, workspaceBytes, stream);
break;
default:
throw std::runtime_error("invalid config for fp8 gemm");
Expand All @@ -95,9 +96,10 @@ size_t dispatchGemmClusterShapeSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const
}

template <typename T>
size_t dispatchToArch(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* alpha, void* D,
int m, int n, int k, int b, CutlassGemmConfig gemmConfig, char* workspacePtr,
size_t const workspaceBytes, cudaStream_t stream) {
size_t dispatchToArch(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* scale_a,
float const* scale_b, void* D, int m, int n, int k, int b,
CutlassGemmConfig gemmConfig, char* workspacePtr, size_t const workspaceBytes,
cudaStream_t stream) {
using namespace cute;

using arch = cutlass::arch::Sm100;
Expand All @@ -107,34 +109,34 @@ size_t dispatchToArch(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float cons
// A rowmajor, B colmajor , C, D rowmajor
switch (gemmConfig.tile_config_sm100) {
case CutlassTileConfigSM100::CtaShape64x64x128B:
return dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(B, A, alpha, static_cast<T*>(D), n,
m, k, b, gemmConfig, workspacePtr,
workspaceBytes, stream);
return dispatchGemmClusterShapeSm100<T, arch, 64, 64, 128>(
B, A, scale_a, scale_b, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr,
workspaceBytes, stream);
break;
case CutlassTileConfigSM100::CtaShape64x128x128B:
return dispatchGemmClusterShapeSm100<T, arch, 64, 128, 128>(
B, A, alpha, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
stream);
B, A, scale_a, scale_b, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr,
workspaceBytes, stream);
break;
case CutlassTileConfigSM100::CtaShape64x256x128B:
return dispatchGemmClusterShapeSm100<T, arch, 64, 256, 128>(
B, A, alpha, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
stream);
B, A, scale_a, scale_b, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr,
workspaceBytes, stream);
break;
case CutlassTileConfigSM100::CtaShape128x64x128B:
return dispatchGemmClusterShapeSm100<T, arch, 128, 64, 128>(
B, A, alpha, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
stream);
B, A, scale_a, scale_b, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr,
workspaceBytes, stream);
break;
case CutlassTileConfigSM100::CtaShape128x128x128B:
return dispatchGemmClusterShapeSm100<T, arch, 128, 128, 128>(
B, A, alpha, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
stream);
B, A, scale_a, scale_b, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr,
workspaceBytes, stream);
break;
case CutlassTileConfigSM100::CtaShape128x256x128B:
return dispatchGemmClusterShapeSm100<T, arch, 128, 256, 128>(
B, A, alpha, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr, workspaceBytes,
stream);
B, A, scale_a, scale_b, static_cast<T*>(D), n, m, k, b, gemmConfig, workspacePtr,
workspaceBytes, stream);
break;

default:
Expand All @@ -145,11 +147,12 @@ size_t dispatchToArch(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float cons

template <typename T>
void CutlassFp8GemmRunner<T>::gemm(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B,
float const* alpha, void* D, int m, int n, int k, int b,
CutlassGemmConfig gemmConfig, char* workspacePtr,
size_t const workspaceBytes, cudaStream_t stream) {
dispatchToArch<T>(A, B, alpha, reinterpret_cast<T*>(D), m, n, k, b, gemmConfig, workspacePtr,
workspaceBytes, stream);
float const* scale_a, float const* scale_b, void* D, int m,
int n, int k, int b, CutlassGemmConfig gemmConfig,
char* workspacePtr, size_t const workspaceBytes,
cudaStream_t stream) {
dispatchToArch<T>(A, B, scale_a, scale_b, reinterpret_cast<T*>(D), m, n, k, b, gemmConfig,
workspacePtr, workspaceBytes, stream);
}

template <typename T>
Expand All @@ -158,8 +161,8 @@ size_t CutlassFp8GemmRunner<T>::getWorkspaceSizeImpl(int m, int n, int k) {
auto gemmConfigs = CutlassFp8GemmRunner<T>{}.getConfigs();
for (auto const& gemmConfig : gemmConfigs) {
try {
size_t curr_workspace_size = dispatchToArch<T>(nullptr, nullptr, nullptr, nullptr, m, n, k, 1,
gemmConfig, nullptr, 0, nullptr);
size_t curr_workspace_size = dispatchToArch<T>(nullptr, nullptr, nullptr, nullptr, nullptr, m,
n, k, 1, gemmConfig, nullptr, 0, nullptr);

workspace_size = std::max(workspace_size, curr_workspace_size);
} catch (std::runtime_error& e) {
Expand Down
43 changes: 32 additions & 11 deletions include/flashinfer/gemm/fp8_gemm_template_sm100.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/numeric_conversion.h"
#include "flashinfer/arch_condition.h"
#include "flashinfer/cutlass_utils.cuh"

Expand Down Expand Up @@ -68,9 +71,10 @@ struct SMTypeAdapter<_2SM> {
template <typename T, typename arch, int32_t CTA_M_, int32_t CTA_N_, int32_t CTA_K_,
typename ClusterShape_, typename XSM_>
size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B,
float const* alpha, T* D, int m, int n, int k, int b,
CutlassGemmConfig config, char* workspacePtr,
size_t const workspaceBytes, cudaStream_t stream) {
float const* scale_a, float const* scale_b, T* D, int m,
int n, int k, int b, CutlassGemmConfig config,
char* workspacePtr, size_t const workspaceBytes,
cudaStream_t stream) {
using namespace cute;

// A matrix configuration
Expand Down Expand Up @@ -124,10 +128,23 @@ size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 c
using MainloopSchedule = typename SMTypeAdapter<XSM_>::MainloopSchedule;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

using CustomEVT = cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>, // scale_a * scale_b * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<float>, // scale_a
cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>, // scale_b * acc
cutlass::epilogue::fusion::Sm90ScalarBroadcast<float>, // scale_b
cutlass::epilogue::fusion::Sm90AccFetch // acc
>>;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, ElementAccumulator,
ElementCompute, ElementC, LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD,
EpilogueSchedule>::CollectiveOp;
EpilogueSchedule, CustomEVT>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB,
Expand Down Expand Up @@ -168,11 +185,15 @@ size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 c
reinterpret_cast<ElementOutput*>(D),
stride_D}};

auto& fusion_args = arguments.epilogue.thread;
fusion_args.alpha = 0.F;
fusion_args.beta = 0.F;
fusion_args.alpha_ptr = alpha;
fusion_args.beta_ptr = nullptr;
arguments.epilogue.thread = {
{{0.F}, {scale_a}}, // scale_a
{
{{0.F}, {scale_b}}, // scale_b
{}, // acc
{} // multiplies
},
{} // multiplies
};

Gemm gemm;

Expand Down Expand Up @@ -210,8 +231,8 @@ size_t genericFp8GemmKernelLauncherSm100(__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 c
template size_t genericFp8GemmKernelLauncherSm100< \
RET_TYPE, cutlass::arch::Sm100, TILE_M, TILE_N, TILE_K, \
cute::Shape<cute::Int<CGA_M_>, cute::Int<CGA_N_>, cute::Int<CGA_K_>>, SM_TYPE>( \
__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* alpha, RET_TYPE* D, int m, \
int n, int k, int b, CutlassGemmConfig config, char* workspacePtr, \
__nv_fp8_e4m3 const* A, __nv_fp8_e4m3 const* B, float const* scale_a, float const* scale_b, \
RET_TYPE* D, int m, int n, int k, int b, CutlassGemmConfig config, char* workspacePtr, \
size_t const workspaceBytes, cudaStream_t stream);

#endif // FLASHINFER_FP8_GEMM_TEMPLATE_SM100_H_