Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,33 +25,43 @@ using namespace cute;
template <class OutType, int ScaleGranularityM,
int ScaleGranularityN, int ScaleGranularityK,
class MmaTileShape, class ClusterShape,
class EpilogueScheduler, class MainloopScheduler>
class EpilogueScheduler, class MainloopScheduler,
bool swap_ab_ = false>
struct cutlass_3x_gemm_fp8_blockwise {
static constexpr bool swap_ab = swap_ab_;
using ElementAB = cutlass::float_e4m3_t;

using ElementA = ElementAB;
using LayoutA = cutlass::layout::RowMajor;
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

using ElementB = ElementAB;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

using ElementD = OutType;
using LayoutD = cutlass::layout::RowMajor;
using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose<LayoutD>::type;
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;

using ElementC = void; // TODO: support bias
using LayoutC = LayoutD;
using LayoutC_Transpose = LayoutD_Transpose;
static constexpr int AlignmentC = AlignmentD;

using ElementAccumulator = float;
using ElementCompute = float;
using ElementBlockScale = float;

using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<
using ScaleConfig = conditional_t<swap_ab,
cutlass::detail::Sm90BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::GMMA::Major::MN, cute::GMMA::Major::K>;
cute::GMMA::Major::K, cute::GMMA::Major::MN>,
cutlass::detail::Sm90BlockwiseScaleConfig<
ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
cute::GMMA::Major::MN, cute::GMMA::Major::K>>;

using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
Expand All @@ -71,30 +81,46 @@ struct cutlass_3x_gemm_fp8_blockwise {
ElementAccumulator,
ElementCompute,
ElementC,
LayoutC,
conditional_t<swap_ab, LayoutC_Transpose, LayoutC>,
AlignmentC,
ElementD,
LayoutD,
conditional_t<swap_ab, LayoutD_Transpose, LayoutD>,
AlignmentD,
EpilogueScheduler,
DefaultOperation
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp;
using CollectiveMainloop = conditional_t<swap_ab,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementB,
cute::tuple<LayoutB_Transpose, LayoutSFA>,
AlignmentB,
ElementA,
cute::tuple<LayoutA_Transpose, LayoutSFB>,
AlignmentA,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp,
typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag,
OperatorClass,
ElementA,
cute::tuple<LayoutA, LayoutSFA>,
AlignmentA,
ElementB,
cute::tuple<LayoutB, LayoutSFB>,
AlignmentB,
ElementAccumulator,
MmaTileShape,
ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
MainloopScheduler
>::CollectiveOp>;

using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue>>;
Expand All @@ -107,6 +133,7 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
static constexpr bool swap_ab = Gemm::swap_ab;
using GemmKernel = typename Gemm::GemmKernel;
using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
Expand All @@ -122,8 +149,6 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te

int32_t m = a.size(0), n = b.size(1), k = a.size(1);

STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4");

StrideA a_stride;
StrideB b_stride;
StrideC c_stride;
Expand All @@ -132,28 +157,42 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
b_stride =
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
c_stride =
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
cutlass::make_cute_packed_stride(
StrideC{}, swap_ab ? cute::make_shape(n, m, 1)
: cute::make_shape(m, n, 1));

LayoutSFA layout_SFA =
ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB =
ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
LayoutSFA layout_SFA = swap_ab
? ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1))
: ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
LayoutSFB layout_SFB = swap_ab
? ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1))
: ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));

auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto a_scales_ptr = static_cast<ElementBlockScale const*>(a_scales.data_ptr());
auto b_scales_ptr = static_cast<ElementBlockScale const*>(b_scales.data_ptr());

typename GemmKernel::MainloopArguments mainloop_args{};
mainloop_args.ptr_A = a_ptr;
mainloop_args.dA = a_stride;
mainloop_args.ptr_B = b_ptr;
mainloop_args.dB = b_stride;
mainloop_args.ptr_SFA = a_scales_ptr;
mainloop_args.layout_SFA = layout_SFA;
mainloop_args.ptr_SFB = b_scales_ptr;
mainloop_args.layout_SFB = layout_SFB;
auto prob_shape = cute::make_shape(m, n, k, 1);
if (swap_ab) {
mainloop_args.ptr_A = b_ptr;
mainloop_args.dA = b_stride;
mainloop_args.ptr_B = a_ptr;
mainloop_args.dB = a_stride;
mainloop_args.ptr_SFA = b_scales_ptr;
mainloop_args.ptr_SFB = a_scales_ptr;
} else {
mainloop_args.ptr_A = a_ptr;
mainloop_args.dA = a_stride;
mainloop_args.ptr_B = b_ptr;
mainloop_args.dB = b_stride;
mainloop_args.ptr_SFA = a_scales_ptr;
mainloop_args.ptr_SFB = b_scales_ptr;
}
auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1)
: cute::make_shape(m, n, k, 1);

auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename GemmKernel::EpilogueArguments epilogue_args{
Expand All @@ -168,12 +207,21 @@ void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out,
torch::stable::Tensor const& b,
torch::stable::Tensor const& a_scales,
torch::stable::Tensor const& b_scales) {
// TODO: better heuristics
bool swap_ab = (a.size(0) % 4) != 0;
if (!swap_ab) {
cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>,
Shape<_1, _2, _1>, cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>>(
out, a, b, a_scales, b_scales);
return;
}

cutlass_gemm_caller_blockwise<cutlass_3x_gemm_fp8_blockwise<
OutType, 1, 128, 128, Shape<_128, _128, _128>,
Shape<_1, _2, _1>, cutlass::epilogue::TmaWarpSpecializedCooperative,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum>>(
out, a, b, a_scales, b_scales);
OutType, 128, 1, 128, Shape<_128, _16, _128>,
Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized,
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8BlockScaledAccum,
true>>(out, a, b, a_scales, b_scales);
}

} // namespace vllm
2 changes: 0 additions & 2 deletions tests/kernels/quantization/test_cutlass_scaled_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,6 @@ def test_cutlass_fp8_blockwise_scale_gemm(
return
if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0:
return
if m % 4 != 0 and current_platform.has_device_capability(100):
return
cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias)


Expand Down
118 changes: 0 additions & 118 deletions vllm/model_executor/kernels/linear/scaled_mm/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op

from .BlockScaledMMLinearKernel import Fp8BlockScaledMMLinearKernel
from .ScaledMMLinearKernel import (
Expand Down Expand Up @@ -268,15 +267,13 @@ class CutlassFp8BlockScaledMMKernel(Fp8BlockScaledMMLinearKernel):
def __init__(self, config: FP8ScaledMMLinearLayerConfig) -> None:
super().__init__(config)
act_scale_descriptor = config.activation_quant_key.scale
self.weight_group_shape = config.weight_quant_key.scale.group_shape
self.quant_fp8 = QuantFP8(
static=act_scale_descriptor.static,
group_shape=act_scale_descriptor.group_shape,
num_token_padding=self.get_output_padding(),
use_ue8m0=False,
column_major_scales=True,
)
self.is_hopper = current_platform.is_device_capability(90)

@classmethod
def is_supported(cls, compute_capability=None):
Expand Down Expand Up @@ -311,16 +308,6 @@ def apply_block_scaled_mm(
Bs: torch.Tensor,
) -> torch.Tensor:
out_dtype = self.config.out_dtype
if self.is_hopper:
return torch.ops.vllm.dynamic_padded_cutlass(
A,
B,
As,
Bs,
list(self.weight_group_shape),
out_dtype,
)

return ops.cutlass_scaled_mm(
A,
B.T,
Expand All @@ -345,108 +332,3 @@ def cutlass_scaled_mm(
scale_a=As,
scale_b=Bs.T,
)


def _padded_cutlass(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
pad_multiple = 4
dim = qx.shape[0]
padded = (
dim if dim % pad_multiple == 0 else dim + pad_multiple - (dim % pad_multiple)
)

has_pad = padded > dim

if has_pad:
padded_shape = [padded, *qx.shape[1:]]
padded_qx = torch.zeros(padded_shape, device=qx.device, dtype=qx.dtype)
padded_qx[0 : qx.shape[0], ...].copy_(qx)

padded_x_scale_shape = [*x_scale.shape[1:], padded]
padded_x_scale = torch.ones(
padded_x_scale_shape, device=x_scale.device, dtype=x_scale.dtype
).permute(-1, -2)
padded_x_scale[0 : x_scale.shape[0], ...].copy_(x_scale)

output = cutlass_scaled_mm(
padded_qx, weight, padded_x_scale, weight_scale, block_size, output_dtype
)
return output[0 : qx.shape[0], ...]
else:
return cutlass_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)


def _padded_cutlass_fake(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
return torch.empty(
(qx.size(0), weight.size(0)), dtype=output_dtype, device=qx.device
)


def _dynamic_padded_cutlass(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
block_size: list[int],
output_dtype: torch.dtype,
) -> torch.Tensor:
def run_padded(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
return _padded_cutlass(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)

def run_direct(
qx: torch.Tensor,
weight: torch.Tensor,
x_scale: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
return cutlass_scaled_mm(
qx, weight, x_scale, weight_scale, block_size, output_dtype
)

if torch.compiler.is_compiling():
return torch.cond(
qx.shape[0] % 4 != 0,
run_padded,
run_direct,
(qx, weight, x_scale, weight_scale),
)

if qx.shape[0] % 4 != 0:
return run_padded(qx, weight, x_scale, weight_scale)

return run_direct(qx, weight, x_scale, weight_scale)


direct_register_custom_op(
"padded_cutlass",
_padded_cutlass,
fake_impl=_padded_cutlass_fake,
)

direct_register_custom_op(
"dynamic_padded_cutlass",
_dynamic_padded_cutlass,
fake_impl=_padded_cutlass_fake,
)
Loading