1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -418,6 +418,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
10 changes: 10 additions & 0 deletions csrc/cutlass_extensions/common.hpp
@@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
#endif
}
};

template <typename Kernel>
struct enable_sm100_only : Kernel {
template <typename... Args>
CUTLASS_DEVICE void operator()(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
Kernel::operator()(std::forward<Args>(args)...);
#endif
}
};
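For context, a minimal usage sketch (editor's note; the kernel name is hypothetical, not part of this diff). Wrapping a kernel type in enable_sm100_only compiles its device body only when __CUDA_ARCH__ is exactly 1000, i.e. compute capability 10.0, so multi-architecture builds get an empty stub on anything other than SM100 (Blackwell):

// Hypothetical sketch: gate any CUTLASS 3.x kernel type to SM100, exactly as
// the GemmUniversal kernel is gated in the dispatch header below.
using GatedKernel = enable_sm100_only<MyCutlassKernel>;  // MyCutlassKernel assumed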
27 changes: 27 additions & 0 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
@@ -0,0 +1,27 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
                                           torch::Tensor const& a_scales,
                                           torch::Tensor const& b_scales) {
  TORCH_CHECK(
      a.size(0) % 4 == 0,
      "Input tensor must have a number of rows that is a multiple of 4, ",
      "but got: ", a.size(0), " rows.");
  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::bfloat16_t>(
        out, a, b, a_scales, b_scales);

  } else {
    TORCH_CHECK(out.dtype() == torch::kFloat16);
    cutlass_gemm_blockwise_sm100_fp8_dispatch<cutlass::half_t>(
        out, a, b, a_scales, b_scales);
  }
}

} // namespace vllm
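A hedged call-site sketch (editor's note, not part of the PR); the tensor shapes follow from the checks in scaled_mm_helper.hpp below, and all names here are illustrative:

// a:        [M, K] float8_e4m3fn, row-major, with M % 4 == 0
// b:        [K, N] float8_e4m3fn, column-major
// a_scales: [M, K/128] float32     (one scale per 1x128 group of a)
// b_scales: [K/128, N/128] float32 (one scale per 128x128 block of b)
torch::Tensor out = torch::empty({a.size(0), b.size(1)},
                                 torch::dtype(torch::kBFloat16).device(a.device()));
vllm::cutlass_scaled_mm_blockwise_sm100_fp8(out, a, b, a_scales, b_scales);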
205 changes: 205 additions & 0 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
@@ -0,0 +1,205 @@
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"

#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"

#include "cutlass_gemm_caller.cuh"

namespace vllm {

using namespace cute;

template <typename OutType, typename MmaTileShape, typename ScalesPerTile,
class ClusterShape, typename EpilogueScheduler,
typename MainloopScheduler>
struct cutlass_3x_gemm_fp8_blockwise {
using ElementAB = cutlass::float_e4m3_t;

using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

using ElementC = void;
using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

using LayoutC = LayoutD;
static constexpr int AlignmentC = AlignmentD;

using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;

  // MMA and Cluster Tile Shapes
  // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if
  // Cluster Shape % 2 == 0 (e.g. MmaTileShape = Shape<_128, _128, _128>).
static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
static constexpr int ScaleGranularityM =
size<0>(MmaTileShape{}) / ScaleMsPerTile;
static constexpr int ScaleGranularityN =
size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
static constexpr int ScaleGranularityK =
size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
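  // Worked example (editor's note): the 1-SM config below instantiates this
  // template with MmaTileShape = (128, 128, 128) and ScalesPerTile =
  // (128, 1, 1), giving ScaleGranularityM = 128 / 128 = 1,
  // ScaleGranularityN = 128 / 1 = 128 and ScaleGranularityK = 128 / 1 = 128,
  // i.e. one scale per 1x128 strip of A and one per 128x128 block of B.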

// Shape of the threadblocks in a cluster
using ClusterShape_MNK = ClusterShape;

using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::UMMA::Major::MN, cute::UMMA::Major::K>;
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());

using ArchTag = cutlass::arch::Sm100;
using OperatorClass = cutlass::arch::OpClassTensorOp;

static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
using ElementScalar = float;
// clang-format off
using DefaultOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementCompute, ElementC, ElementScalar, RoundStyle>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
MmaTileShape,
ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto,
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
AlignmentC,
ElementD,
LayoutD,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;

using StageCountType = cutlass::gemm::collective::StageCountAuto;
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,

cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
// clang-format on

using KernelType = enable_sm100_only<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;

struct GemmKernel : public KernelType {};
};

template <typename Gemm>
void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideC = typename Gemm::GemmKernel::StrideC;
using LayoutSFA = typename Gemm::LayoutSFA;
using LayoutSFB = typename Gemm::LayoutSFB;
using ScaleConfig = typename Gemm::ScaleConfig;

using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;

int32_t m = a.size(0), n = b.size(1), k = a.size(1);
auto prob_shape = cute::make_shape(m, n, k, 1);

StrideA a_stride;
StrideB b_stride;
StrideC c_stride;
a_stride =
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));

LayoutSFA layout_SFA =
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB =
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

auto a_ptr = static_cast<ElementAB*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB*>(b.data_ptr());
auto a_scales_ptr = static_cast<float*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<float*>(b_scales.data_ptr());

typename GemmKernel::MainloopArguments mainloop_args{
a_ptr, a_stride, b_ptr, b_stride,
a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};

auto c_ptr = static_cast<ElementD*>(out.data_ptr());
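  // Editor's note: ElementC is void, so no real source operand is read; the
  // output pointer/stride fill both the C and D slots, and the empty {} keeps
  // the default LinearCombination fusion parameters (alpha = 1, beta = 0).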
typename GemmKernel::EpilogueArguments epilogue_args{
{}, c_ptr, c_stride, c_ptr, c_stride};
c3x::cutlass_gemm_caller<GemmKernel>(a.device(), prob_shape, mainloop_args,
epilogue_args);
}

template <typename OutType>
void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
auto m = a.size(0);
auto k = a.size(1);
auto n = b.size(1);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());

auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
return std::ceil(static_cast<float>(m) / tile1SM) *
std::ceil(static_cast<float>(n) / tile1SM) >=
sms;
};
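  // Heuristic (editor's note): prefer the 2-SM cluster configuration once the
  // grid of 128x128 output tiles is large enough to occupy every SM. For
  // example, m = n = 4096 gives ceil(4096/128)^2 = 32 * 32 = 1024 tiles,
  // which exceeds the SM count of current devices, so the 2-SM path is taken.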
bool use_2sm = should_use_2sm(m, n);
if (use_2sm) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>,
Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
out, a, b, a_scales, b_scales);
} else {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
out, a, b, a_scales, b_scales);
}
}

} // namespace vllm
57 changes: 57 additions & 0 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
@@ -0,0 +1,57 @@
#include <torch/all.h>
#include "cuda_utils.h"

template <typename Fp8Func, typename Int8Func, typename BlockwiseFunc>
void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias,
Fp8Func fp8_func, Int8Func int8_func,
BlockwiseFunc blockwise_func) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

int M = a.size(0), N = b.size(1), K = a.size(1);

if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1))) {
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (a.dtype() == torch::kFloat8_e4m3fn) {
fp8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(a.dtype() == torch::kInt8);
if constexpr (!std::is_same_v<Int8Func, std::nullptr_t>) {
int8_func(c, a, b, a_scales, b_scales, bias);
} else {
TORCH_CHECK(false, "Int8 not supported for this architecture");
}
}
} else {
using GroupShape = std::array<int64_t, 2>;
auto make_group_shape = [](torch::Tensor const& x,
torch::Tensor const& s) -> GroupShape {
TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D");
return {cuda_utils::ceil_div(x.size(0), s.size(0)),
cuda_utils::ceil_div(x.size(1), s.size(1))};
};
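    // Example (editor's note): for a of shape [M, K] with a_scales of shape
    // [M, K/128], this deduces {ceil_div(M, M), ceil_div(K, K/128)} = {1, 128}.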

GroupShape a_scale_group_shape = make_group_shape(a, a_scales);
GroupShape b_scale_group_shape = make_group_shape(b, b_scales);

// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} &&
b_scale_group_shape == GroupShape{128, 128} &&
a.dtype() == torch::kFloat8_e4m3fn &&
b.dtype() == torch::kFloat8_e4m3fn),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.\n"
"a_scale_group_shape must be [1, 128]. Got: [",
a_scale_group_shape[0], ", ", a_scale_group_shape[1],
"]\n"
"b_scale_group_shape must be [128, 128]. Got: [",
b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
    TORCH_CHECK(!bias, "Bias not yet supported for blockwise scaled_mm");
blockwise_func(c, a, b, a_scales, b_scales);
}
}
5 changes: 5 additions & 0 deletions csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
} // namespace vllm
22 changes: 5 additions & 17 deletions csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
@@ -1,8 +1,6 @@
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"

#include "cuda_utils.h"

/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell).
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

int M = a.size(0), N = b.size(1), K = a.size(1);
TORCH_CHECK(
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell");

// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
"Currently, only fp8 gemm is implemented for Blackwell");
vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
vllm::cutlass_scaled_mm_sm100_fp8,
nullptr, // int8 not supported on SM100
vllm::cutlass_scaled_mm_blockwise_sm100_fp8);
}

#endif
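For contrast, a hedged sketch of how an architecture with int8 support would wire dispatch_scaled_mm (editor's note; the sm90 function names are assumed from the surrounding headers and are not shown in this diff). Passing a real callable instead of nullptr makes the if constexpr branch in the helper invoke it rather than hit the "Int8 not supported" TORCH_CHECK:

dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias,
                   vllm::cutlass_scaled_mm_sm90_fp8,
                   vllm::cutlass_scaled_mm_sm90_int8,
                   vllm::cutlass_scaled_mm_blockwise_sm90_fp8);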