Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")

# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "v4.2.1")
set(CUTLASS_REVISION "v4.3.5")

# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
Expand Down Expand Up @@ -771,6 +771,31 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()

# The MoE kernel cutlass_moe_mm for SM120
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}")
else()
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a" "${CUDA_ARCHS}")
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0 AND SCALED_MM_ARCHS)
set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm120.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM120=1")
message(STATUS "Building grouped_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0 AND SCALED_MM_ARCHS)
message(STATUS "Not building grouped_mm_c3x_sm120 kernels as CUDA Compiler version is "
"not >= 13.0, we recommend upgrading to CUDA 13.0 or later "
"if you intend on running FP8 quantized MoE models on SM120.")
else()
message(STATUS "Not building grouped_mm_c3x_sm120 as no compatible archs found "
"in CUDA target architectures.")
endif()
endif()

# moe_data.cu is used by all CUTLASS MoE kernels.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}")
Expand Down
140 changes: 140 additions & 0 deletions csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm120.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#include <cudaTypedefs.h>

#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

using namespace cute;

namespace {

template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_default {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm120;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_M64 {
// M in [1,64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm120;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule,
true>;
};

template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm120_fp8_config_N8192 {
// N in [8192, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
using ArchTag = cutlass::arch::Sm120;

using Cutlass3xGemm =
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>;
};

template <typename InType, typename OutType>
void run_cutlass_moe_mm_sm120(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");

TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
"A tensors must be of type float8_e4m3fn.");
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
"B tensors must be of type float8_e4m3fn.");

using Cutlass3xGemmDefault = typename sm120_fp8_config_default<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmN8192 = typename sm120_fp8_config_N8192<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
using Cutlass3xGemmM64 = typename sm120_fp8_config_M64<
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;

uint32_t const m = a_tensors.size(0);
uint32_t const n = out_tensors.size(1);

if (m <= 64) {
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else if (n >= 8192) {
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}
} // namespace

void dispatch_moe_mm_sm120(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
if (out_tensors.dtype() == torch::kBFloat16) {
run_cutlass_moe_mm_sm120<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
} else {
run_cutlass_moe_mm_sm120<cutlass::float_e4m3_t, cutlass::half_t>(
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
per_out_ch);
}
}

void cutlass_moe_mm_sm120(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
dispatch_moe_mm_sm120(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
}
22 changes: 20 additions & 2 deletions csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ void cutlass_moe_mm_sm100(
bool per_act_token, bool per_out_ch);
#endif

#if defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120
void cutlass_moe_mm_sm120(
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch);
#endif

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
Expand Down Expand Up @@ -253,8 +263,16 @@ void cutlass_moe_mm(
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
bool per_act_token, bool per_out_ch) {
int32_t version_num = get_sm_version_num();
#if defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120
if (version_num >= 120) {
cutlass_moe_mm_sm120(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
return;
}
#endif
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
if (version_num >= 100 && version_num < 110) {
if (version_num >= 100 && version_num < 120) {
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
expert_offsets, problem_sizes, a_strides, b_strides,
c_strides, per_act_token, per_out_ch);
Expand All @@ -272,7 +290,7 @@ void cutlass_moe_mm(
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
". Required capability: 90 or 100");
". Required capability: 90, 100, or 120");
}

void get_cutlass_moe_mm_data(
Expand Down