From d001b23baf07724fbe8d7cf5c7578d9d99d4c8a7 Mon Sep 17 00:00:00 2001 From: shuw Date: Mon, 10 Mar 2025 20:00:19 +0000 Subject: [PATCH 1/7] Support cutlass w8a8 blockwise matmul for blackwell Signed-off-by: Shu Wang --- CMakeLists.txt | 1 + csrc/cutlass_extensions/common.hpp | 12 +- .../c3x/scaled_mm_blockwise_sm100_fp8.cu | 60 ++++++ ...scaled_mm_blockwise_sm100_fp8_dispatch.cuh | 185 ++++++++++++++++++ .../cutlass_w8a8/c3x/scaled_mm_kernels.hpp | 5 + .../cutlass_w8a8/scaled_mm_c3x_sm100.cu | 46 ++++- 6 files changed, 298 insertions(+), 11 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu create mode 100644 csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh diff --git a/CMakeLists.txt b/CMakeLists.txt index 4b3bfe0af7f5..cad9f4428653 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -418,6 +418,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index dbe0e30f5cbf..bbe623d1a76f 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -1,4 +1,4 @@ -#pragma once +;#pragma once #include "cutlass/cutlass.h" #include @@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel { #endif } }; + +template +struct enable_sm100_or_later : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu new file mode 100644 index 000000000000..a8852e5cc269 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu @@ -0,0 +1,60 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +// Pads to a multiple of `alignment` rows. +inline torch::Tensor pad_tensor(const torch::Tensor& tensor, + int64_t alignment = 4, + bool is_column_major = false) { + int64_t rows = tensor.size(0); + int64_t cols = tensor.size(1); + int64_t pad_rows = (alignment - (rows % alignment)) % alignment; + + if (pad_rows == 0) { + return tensor; + } + + torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options()); + torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); + + // Ensure column-major layout + if (is_column_major) { + return tensor_padded.t().contiguous().t(); + } + return tensor_padded; +} + +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + int64_t original_rows = a.size(0); + + torch::Tensor a_padded = pad_tensor(a, /*alignment=*/4); + torch::Tensor a_scales_padded = + pad_tensor(a_scales, /*alignment=*/4, /*col_major=*/true); + torch::Tensor out_padded; + if (a_padded.size(0) == a.size(0)) { + out_padded = out; + } else { + out_padded = + torch::zeros({a_padded.size(0), b.size(1)}, out.options()).contiguous(); + } + + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out_padded, a_padded, b, a_scales_padded, b_scales); + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out_padded, a_padded, b, a_scales_padded, b_scales); + } + if (a_padded.size(0) != a.size(0)) { + out.copy_(out_padded.slice(0, 0, original_rows)); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh new file mode 100644 index 000000000000..d13325ad95d2 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -0,0 +1,185 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +template > +struct cutlass_3x_gemm_fp8_blockwise { + using TileSizeM = Int; + + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using ElementD = OutType; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // using StrideC = StrideD; + using LayoutC = LayoutD; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBlockScale = float; + + // MMA and Cluster Tile Shapes + // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster + // Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>; + static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{}); + static constexpr int ScaleGranularityM = + size<0>(MmaTileShape{}) / ScaleMsPerTile; + static constexpr int ScaleGranularityN = + size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{}); + static constexpr int ScaleGranularityK = + size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{}); + + // Shape of the threadblocks in a cluster + using ClusterShape_MNK = Shape<_1, _1, _1>; + + // using ScaleConfig = + // decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{})); + // static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / + // ScaleMsPerTile; + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::UMMA::Major::MN, cute::UMMA::Major::K>; + + using LayoutSFA = + decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix + // operand + using LayoutSFB = + decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix + // operand + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using AtomThrShape = Shape<_1, _1, _1>; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, PerSmTileShape, ClusterShape, + EpilogueTileShape, ElementAccumulator, ElementCompute, ElementC, + LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, + cutlass::epilogue::TmaWarpSpecialized1Sm>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, cute::tuple, + AlignmentA, ElementB, cute::tuple, AlignmentB, + ElementAccumulator, MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>:: + CollectiveOp; + + using KernelType = enable_sm100_or_later, CollectiveMainloop, CollectiveEpilogue, + cutlass::gemm::PersistentScheduler>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + auto prob_shape = cute::make_shape(m, n, k, 1); + + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); + + LayoutSFA layout_SFA = + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, + a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + epilogue_args.thread.alpha = 1.0f; + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto m = a.size(0); + auto k = a.size(1); + auto n = b.size(1); + + // Define tile shapes based on the value of m + if (m <= 128) { + cutlass_gemm_caller_blockwise, Shape<_64, _128, _128>, + Shape<_64, _64>, Shape<_64, _1, _1>>>(out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Shape<_128, _128, _128>, + Shape<_128, _64>, Shape<_128, _1, _1>>>(out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index 85272804774d..c1242fdb39da 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu index 459eb1bb76eb..f989512fe8e0 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu @@ -19,16 +19,42 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); int M = a.size(0), N = b.size(1), K = a.size(1); - TORCH_CHECK( - (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1)), - "Currently, block scaled fp8 gemm is not implemented for Blackwell"); - - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn, - "Currently, only fp8 gemm is implemented for Blackwell"); - vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias); + + if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && + (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { + // Standard per-tensor/per-token/per-channel scaling + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn, + "Currently, only fp8 gemm is implemented for Blackwell"); + vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias); + } else { + using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + + // 1x128 per-token group scales for activations + // 128x128 blockwise scales for weights + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" + "a_scale_group_shape must be [1, 128]. Got: [", + a_scale_group_shape[0], ", ", a_scale_group_shape[1], + "]\n" + "b_scale_group_shape must be [128, 128]. Got: [", + b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); + TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); + + vllm::cutlass_scaled_mm_blockwise_sm100_fp8(c, a, b, a_scales, b_scales); + } } #endif From b45e6bb22b85316285f5293e8f59573d581b1dd3 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Thu, 13 Mar 2025 16:43:45 -0500 Subject: [PATCH 2/7] Update CUDA_VERSION requirement. Signed-off-by: Shu Wang --- csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 54b63894e4cb..ddcc48cccab1 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -110,6 +110,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { #if defined CUDA_VERSION if (cuda_device_capability >= 90 && cuda_device_capability < 100) { return CUDA_VERSION >= 12000; + } else if (cuda_device_capability >= 100) { + return CUDA_VERSION >= 12080; } #endif From 5100a33a1507cff9b908d72b73fe6c194b70e968 Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 28 Mar 2025 18:42:35 +0000 Subject: [PATCH 3/7] Improve after review 1 Signed-off-by: Shu Wang --- csrc/cutlass_extensions/common.hpp | 2 +- .../c3x/scaled_mm_blockwise_sm100_fp8.cu | 47 +++------------ ...scaled_mm_blockwise_sm100_fp8_dispatch.cuh | 56 +++++++++++------- .../cutlass_w8a8/c3x/scaled_mm_helper.hpp | 57 +++++++++++++++++++ .../cutlass_w8a8/scaled_mm_c3x_sm100.cu | 48 ++-------------- .../cutlass_w8a8/scaled_mm_c3x_sm90.cu | 51 ++--------------- .../quantization/test_cutlass_scaled_mm.py | 4 +- .../layers/quantization/utils/fp8_utils.py | 11 ++++ 8 files changed, 125 insertions(+), 151 deletions(-) create mode 100644 csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index bbe623d1a76f..7d7dcfe76c2f 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -1,4 +1,4 @@ -;#pragma once +#pragma once #include "cutlass/cutlass.h" #include diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu index a8852e5cc269..84492553c02f 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu @@ -4,56 +4,23 @@ namespace vllm { -// Pads to a multiple of `alignment` rows. -inline torch::Tensor pad_tensor(const torch::Tensor& tensor, - int64_t alignment = 4, - bool is_column_major = false) { - int64_t rows = tensor.size(0); - int64_t cols = tensor.size(1); - int64_t pad_rows = (alignment - (rows % alignment)) % alignment; - - if (pad_rows == 0) { - return tensor; - } - - torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options()); - torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); - - // Ensure column-major layout - if (is_column_major) { - return tensor_padded.t().contiguous().t(); - } - return tensor_padded; -} - void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, torch::Tensor const& b_scales) { - int64_t original_rows = a.size(0); - - torch::Tensor a_padded = pad_tensor(a, /*alignment=*/4); - torch::Tensor a_scales_padded = - pad_tensor(a_scales, /*alignment=*/4, /*col_major=*/true); - torch::Tensor out_padded; - if (a_padded.size(0) == a.size(0)) { - out_padded = out; - } else { - out_padded = - torch::zeros({a_padded.size(0), b.size(1)}, out.options()).contiguous(); - } - + TORCH_CHECK( + a.size(0) % 4 == 0, + "Input tensor must have a number of rows that is a multiple of 4. ", + "but got: ", a.size(0), " rows."); if (out.dtype() == torch::kBFloat16) { cutlass_gemm_blockwise_sm100_fp8_dispatch( - out_padded, a_padded, b, a_scales_padded, b_scales); + out, a, b, a_scales, b_scales); + } else { TORCH_CHECK(out.dtype() == torch::kFloat16); cutlass_gemm_blockwise_sm100_fp8_dispatch( - out_padded, a_padded, b, a_scales_padded, b_scales); - } - if (a_padded.size(0) != a.size(0)) { - out.copy_(out_padded.slice(0, 0, original_rows)); + out, a, b, a_scales, b_scales); } } diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index d13325ad95d2..5ea643643406 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -65,10 +65,6 @@ struct cutlass_3x_gemm_fp8_blockwise { // Shape of the threadblocks in a cluster using ClusterShape_MNK = Shape<_1, _1, _1>; - // using ScaleConfig = - // decltype(cutlass::detail::sm100_trivial_blockwise_scale_config(MmaTileShape_MNK{})); - // static constexpr int ScaleGranularityM = size<0>(MmaTileShape{}) / - // ScaleMsPerTile; using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::UMMA::Major::MN, cute::UMMA::Major::K>; @@ -84,22 +80,42 @@ struct cutlass_3x_gemm_fp8_blockwise { using AtomThrShape = Shape<_1, _1, _1>; - using CollectiveEpilogue = - typename cutlass::epilogue::collective::CollectiveBuilder< - ArchTag, OperatorClass, PerSmTileShape, ClusterShape, - EpilogueTileShape, ElementAccumulator, ElementCompute, ElementC, - LayoutC, AlignmentC, ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecialized1Sm>::CollectiveOp; - - using CollectiveMainloop = - typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, ElementA, cute::tuple, - AlignmentA, ElementB, cute::tuple, AlignmentB, - ElementAccumulator, MmaTileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>:: - CollectiveOp; + // clang-format off + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + PerSmTileShape, + ClusterShape, + EpilogueTileShape, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + cutlass::epilogue::TmaWarpSpecialized1Sm + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 + >::CollectiveOp; + // clang-format on using KernelType = enable_sm100_or_later, CollectiveMainloop, CollectiveEpilogue, diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp new file mode 100644 index 000000000000..b589a479081e --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -0,0 +1,57 @@ +#include +#include "cuda_utils.h" + +template +void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias, + Fp8Func fp8_func, Int8Func int8_func, + BlockwiseFunc blockwise_func) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + int M = a.size(0), N = b.size(1), K = a.size(1); + + if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && + (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { + // Standard per-tensor/per-token/per-channel scaling + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.dtype() == torch::kFloat8_e4m3fn) { + fp8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(a.dtype() == torch::kInt8); + if constexpr (!std::is_same_v) { + int8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(false, "Int8 not supported for this architecture"); + } + } + } else { + using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + + // 1x128 per-token group scales for activations + // 128x128 blockwise scales for weights + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" + "a_scale_group_shape must be [1, 128]. Got: [", + a_scale_group_shape[0], ", ", a_scale_group_shape[1], + "]\n" + "b_scale_group_shape must be [128, 128]. Got: [", + b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); + TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); + blockwise_func(c, a, b, a_scales, b_scales); + } +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu index f989512fe8e0..0cbd5305e3c2 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu @@ -1,8 +1,6 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm100 (Blackwell). @@ -15,46 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - - if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn, - "Currently, only fp8 gemm is implemented for Blackwell"); - vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias); - } else { - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; - - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); - TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); - - vllm::cutlass_scaled_mm_blockwise_sm100_fp8(c, a, b, a_scales, b_scales); - } + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm100_fp8, + nullptr, // int8 not supported on SM100 + vllm::cutlass_scaled_mm_blockwise_sm100_fp8); } #endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu index bcb91040d5e2..211302171f07 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu @@ -1,8 +1,6 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper). @@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - - if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (a.dtype() == torch::kFloat8_e4m3fn) { - vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias); - } else { - TORCH_CHECK(a.dtype() == torch::kInt8); - vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias); - } - } else { - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; - - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); - TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); - - vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales); - } + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm90_fp8, + vllm::cutlass_scaled_mm_sm90_int8, + vllm::cutlass_scaled_mm_blockwise_sm90_fp8); } void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 8084d9bf2c2d..633addd421f4 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int, out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) + torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1) opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) @@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int, return if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: return + if m % 4 != 0 and current_platform.has_device_capability(100): + return cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 064cbb8cf52d..9475de9f0b80 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -57,6 +57,15 @@ def apply_w8a8_block_fp8_linear( or br not in (1, weight.shape[0])): shape_supported_by_cutlass = False if cutlass_block_fp8_supported and shape_supported_by_cutlass: + rows, cols = input_2d.shape + should_pad = current_platform.has_device_capability( + 100) and rows % 4 != 0 + if should_pad: + padding = torch.zeros((4 - (rows % 4), cols), + dtype=input_2d.dtype, + device=input_2d.device) + input_2d = torch.cat([input_2d, padding], dim=0).contiguous() + q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1], column_major_scales=True) @@ -65,6 +74,8 @@ def apply_w8a8_block_fp8_linear( out_dtype=input.dtype, scale_a=x_scale, scale_b=weight_scale.T) + if should_pad: + output = output[:rows, :] else: q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1], From 24a528e13101b7a1237ddb1a8d15cd33b6b4ddb3 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Fri, 4 Apr 2025 15:53:38 -0500 Subject: [PATCH 4/7] Correct forward compatibility. Signed-off-by: Shu Wang --- csrc/cutlass_extensions/common.hpp | 4 ++-- .../c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index 7d7dcfe76c2f..0877da52435e 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -61,10 +61,10 @@ struct enable_sm90_only : Kernel { }; template -struct enable_sm100_or_later : Kernel { +struct enable_sm100_only : Kernel { template CUTLASS_DEVICE void operator()(Args&&... args) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 1000 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 Kernel::operator()(std::forward(args)...); #endif } diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index 5ea643643406..0c1f910b64ee 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -117,7 +117,7 @@ struct cutlass_3x_gemm_fp8_blockwise { >::CollectiveOp; // clang-format on - using KernelType = enable_sm100_or_later, CollectiveMainloop, CollectiveEpilogue, cutlass::gemm::PersistentScheduler>>; From c030b940efd33fc84d3e7c671605d9636475af8b Mon Sep 17 00:00:00 2001 From: shuw Date: Tue, 22 Apr 2025 20:55:17 +0000 Subject: [PATCH 5/7] Branching with 1/2SMs Signed-off-by: Shu Wang --- ...scaled_mm_blockwise_sm100_fp8_dispatch.cuh | 70 ++++++++++--------- 1 file changed, 37 insertions(+), 33 deletions(-) diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index 0c1f910b64ee..ef324364c6d5 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -22,12 +22,10 @@ namespace vllm { using namespace cute; -template > +template struct cutlass_3x_gemm_fp8_blockwise { - using TileSizeM = Int; - using ElementAB = cutlass::float_e4m3_t; using ElementA = ElementAB; @@ -43,7 +41,6 @@ struct cutlass_3x_gemm_fp8_blockwise { using LayoutD = cutlass::layout::RowMajor; static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; - // using StrideC = StrideD; using LayoutC = LayoutD; static constexpr int AlignmentC = AlignmentD; @@ -63,30 +60,27 @@ struct cutlass_3x_gemm_fp8_blockwise { size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{}); // Shape of the threadblocks in a cluster - using ClusterShape_MNK = Shape<_1, _1, _1>; + using ClusterShape_MNK = ClusterShape; using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, cute::UMMA::Major::MN, cute::UMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); - using LayoutSFA = - decltype(ScaleConfig::deduce_layoutSFA()); // Layout type for SFA matrix - // operand - using LayoutSFB = - decltype(ScaleConfig::deduce_layoutSFB()); // Layout type for SFB matrix - // operand using ArchTag = cutlass::arch::Sm100; using OperatorClass = cutlass::arch::OpClassTensorOp; - using AtomThrShape = Shape<_1, _1, _1>; - + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; // clang-format off + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, - PerSmTileShape, + MmaTileShape, ClusterShape, - EpilogueTileShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementCompute, ElementC, @@ -95,9 +89,11 @@ struct cutlass_3x_gemm_fp8_blockwise { ElementD, LayoutD, AlignmentD, - cutlass::epilogue::TmaWarpSpecialized1Sm + EpilogueScheduler, + DefaultOperation >::CollectiveOp; - + + using StageCountType = cutlass::gemm::collective::StageCountAuto; using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< ArchTag, OperatorClass, @@ -110,16 +106,14 @@ struct cutlass_3x_gemm_fp8_blockwise { ElementAccumulator, MmaTileShape, ClusterShape, - cutlass::gemm::collective::StageCountAutoCarveout< - static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) - >, - cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100 + + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler >::CollectiveOp; // clang-format on using KernelType = enable_sm100_only, CollectiveMainloop, CollectiveEpilogue, - cutlass::gemm::PersistentScheduler>>; + Shape, CollectiveMainloop, CollectiveEpilogue>>; struct GemmKernel : public KernelType {}; }; @@ -171,7 +165,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, auto c_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ {}, c_ptr, c_stride, c_ptr, c_stride}; - epilogue_args.thread.alpha = 1.0f; c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, epilogue_args); } @@ -185,16 +178,27 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, auto m = a.size(0); auto k = a.size(1); auto n = b.size(1); - - // Define tile shapes based on the value of m - if (m <= 128) { + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) { + return std::ceil(static_cast(m) / tile1SM) * + std::ceil(static_cast(n) / tile1SM) >= + sms; + }; + bool use_2sm = should_use_2sm(m, n); + if (use_2sm) { cutlass_gemm_caller_blockwise, Shape<_64, _128, _128>, - Shape<_64, _64>, Shape<_64, _1, _1>>>(out, a, b, a_scales, b_scales); + OutType, Shape<_256, _128, _128>, Shape<_256, _1, _1>, + Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( + out, a, b, a_scales, b_scales); } else { cutlass_gemm_caller_blockwise, Shape<_128, _128, _128>, - Shape<_128, _64>, Shape<_128, _1, _1>>>(out, a, b, a_scales, b_scales); + OutType, Shape<_128, _128, _128>, Shape<_128, _1, _1>, + Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); } } From 9bab9f6a37e96e27c29d802781f7e257861da1ef Mon Sep 17 00:00:00 2001 From: shuw Date: Fri, 25 Apr 2025 21:06:43 +0000 Subject: [PATCH 6/7] Add comments for multiple of 4 constraint Signed-off-by: Shu Wang --- .../layers/quantization/utils/fp8_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 9475de9f0b80..a7a19579a65f 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -58,13 +58,15 @@ def apply_w8a8_block_fp8_linear( shape_supported_by_cutlass = False if cutlass_block_fp8_supported and shape_supported_by_cutlass: rows, cols = input_2d.shape + # Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for + # optimal tensor core usage. Can be removed when targeting platforms + # without this constraint. should_pad = current_platform.has_device_capability( 100) and rows % 4 != 0 if should_pad: - padding = torch.zeros((4 - (rows % 4), cols), - dtype=input_2d.dtype, - device=input_2d.device) - input_2d = torch.cat([input_2d, padding], dim=0).contiguous() + input_2d = torch.nn.functional.pad( + input_2d, (0, 0, 0, 4 - (rows % 4)), value=0 + ).contiguous() q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1], From b6783dbb1f4c321d8932848ce68d50276a134976 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Sun, 4 May 2025 10:42:09 -0500 Subject: [PATCH 7/7] Fix precommit. Signed-off-by: Shu Wang --- vllm/model_executor/layers/quantization/utils/fp8_utils.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index a7a19579a65f..3bb42e737f10 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -64,10 +64,9 @@ def apply_w8a8_block_fp8_linear( should_pad = current_platform.has_device_capability( 100) and rows % 4 != 0 if should_pad: - input_2d = torch.nn.functional.pad( - input_2d, (0, 0, 0, 4 - (rows % 4)), value=0 - ).contiguous() - + input_2d = torch.nn.functional.pad(input_2d, + (0, 0, 0, 4 - (rows % 4)), + value=0).contiguous() q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1], column_major_scales=True)