diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index f2652045526b..4587342c2d24 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -667,7 +667,6 @@ steps:
   # Quantization
   - pytest -v -s tests/kernels/quantization/test_cutlass_scaled_mm.py -k 'fp8'
   - pytest -v -s tests/kernels/quantization/test_nvfp4_quant.py
-  - pytest -v -s tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
   - pytest -v -s tests/kernels/quantization/test_nvfp4_scaled_mm.py
   - pytest -v -s tests/kernels/quantization/test_flashinfer_scaled_mm.py
   - pytest -v -s tests/kernels/quantization/test_flashinfer_nvfp4_scaled_mm.py
@@ -677,7 +676,6 @@ steps:
   - pytest -v -s tests/compile/test_fusion_all_reduce.py
  - pytest -v -s tests/compile/test_fusion_attn.py::test_attention_quant_pattern
   - pytest -v -s tests/kernels/moe/test_flashinfer.py
-  - pytest -v -s tests/compile/test_silu_mul_quant_fusion.py
 ##### 1 GPU test #####

 ##### multi gpus test #####
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e92e08f0d0ec..b0eb0f32e03a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -541,7 +541,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
-      "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
       "csrc/quantization/fp4/nvfp4_scaled_mm_sm120_kernels.cu")
     set_gencode_flags_for_srcs(
       SRCS "${SRCS}"
@@ -560,7 +559,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
   if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
     set(SRCS
       "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
-      "csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu"
       "csrc/quantization/fp4/nvfp4_experts_quant.cu"
       "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
       "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu")
diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h
index 2728aa81f0c9..f7b75c48373f 100644
--- a/csrc/dispatch_utils.h
+++ b/csrc/dispatch_utils.h
@@ -19,13 +19,6 @@
 #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

-#define VLLM_DISPATCH_CASE_HALF_TYPES(...)                \
-  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)     \
-  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
-
-#define VLLM_DISPATCH_HALF_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_HALF_TYPES(__VA_ARGS__))
-
 // ROCm devices might use either fn or fnuz, so set up dispatch table for both.
 // A host-based check at runtime will create a preferred FP8 type for ROCm
 // such that the correct kernel is dispatched.
@@ -52,15 +45,6 @@
 #define VLLM_DISPATCH_FP8_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FP8_TYPES(__VA_ARGS__))

-#define AT_DISPATCH_BYTE_CASE(enum_type, ...) \
-  AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, byte_t, __VA_ARGS__)
-
-#define VLLM_DISPATCH_CASE_BYTE_TYPES(...) \
-  AT_DISPATCH_BYTE_CASE(at::ScalarType::Byte, __VA_ARGS__)
-
-#define VLLM_DISPATCH_BYTE_TYPES(TYPE, NAME, ...) \
-  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_BYTE_TYPES(__VA_ARGS__))
-
 #define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))
diff --git a/csrc/ops.h b/csrc/ops.h
index 78a487201bdd..86fe848e2fd5 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -130,14 +130,6 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
 void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input,
                         torch::Tensor& scale);

-#ifndef USE_ROCM
-
-void silu_and_mul_nvfp4_quant(torch::Tensor& out,
-                              torch::Tensor& output_block_scale,
-                              torch::Tensor& input,
-                              torch::Tensor& input_global_scale);
-#endif
-
 void mul_and_silu(torch::Tensor& out, torch::Tensor& input);

 void gelu_and_mul(torch::Tensor& out, torch::Tensor& input);
diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
deleted file mode 100644
index 9bbeb0334fb9..000000000000
--- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu
+++ /dev/null
@@ -1,368 +0,0 @@
-/*
- * Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <torch/all.h>
-
-#include <ATen/cuda/CUDAContext.h>
-#include <c10/cuda/CUDAGuard.h>
-
-#include <cuda_fp16.h>
-#include <cuda_bf16.h>
-
-#include <cuda_fp8.h>
-#include "dispatch_utils.h"
-
-#include "cuda_utils.h"
-
-namespace vllm {
-
-// Get type2 from type or vice versa (applied to half and bfloat16)
-template <typename T>
-struct TypeConverter {
-  using Type = half2;
-};  // keep for generality
-
-template <>
-struct TypeConverter<half2> {
-  using Type = c10::Half;
-};
-
-template <>
-struct TypeConverter<c10::Half> {
-  using Type = half2;
-};
-
-template <>
-struct TypeConverter<__nv_bfloat162> {
-  using Type = c10::BFloat16;
-};
-
-template <>
-struct TypeConverter<c10::BFloat16> {
-  using Type = __nv_bfloat162;
-};
-
-#define ELTS_PER_THREAD 8
-
-constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
-constexpr int CVT_FP4_SF_VEC_SIZE = 16;
-
-// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
-inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  uint32_t val;
-  asm volatile(
-      "{\n"
-      ".reg .b8 byte0;\n"
-      ".reg .b8 byte1;\n"
-      ".reg .b8 byte2;\n"
-      ".reg .b8 byte3;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
-      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
-      "}"
-      : "=r"(val)
-      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
-        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
-  return val;
-#else
-  return 0;
-#endif
-}
-
-// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
-inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  uint32_t val;
-  asm volatile(
-      "{\n"
-      ".reg .b8 byte0;\n"
-      ".reg .b8 byte1;\n"
-      ".reg .b8 byte2;\n"
-      ".reg .b8 byte3;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
-      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
-      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
-      "}"
-      : "=r"(val)
-      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
-        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
-  return val;
-#else
-  return 0;
-#endif
-}
-
-// Fast reciprocal.
-inline __device__ float reciprocal_approximate_ftz(float a) {
-  float b;
-  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
-  return b;
-}
-
-template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
-__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
-                                                       int numCols,
-                                                       SFType* SFout) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
-                CVT_FP4_NUM_THREADS_PER_SF == 2);
-
-  // One pair of threads write one SF to global memory.
-  // TODO: stage through smem for packed STG.32
-  // is it better than STG.8 from 4 threads ?
-  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
-    // SF vector index (16 elements share one SF in the K dimension).
-    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
-    int32_t mIdx = rowIdx;
-
-    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
-    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
-
-    int32_t mTileIdx = mIdx / (32 * 4);
-    // SF vector size 16.
-    int factor = CVT_FP4_SF_VEC_SIZE * 4;
-    int32_t numKTiles = (numCols + factor - 1) / factor;
-    int64_t mTileStride = numKTiles * 32 * 4 * 4;
-
-    int32_t kTileIdx = (kIdx / 4);
-    int64_t kTileStride = 32 * 4 * 4;
-
-    // M tile layout [32, 4] is column-major.
-    int32_t outerMIdx = (mIdx % 32);
-    int64_t outerMStride = 4 * 4;
-
-    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
-    int64_t innerMStride = 4;
-
-    int32_t innerKIdx = (kIdx % 4);
-    int64_t innerKStride = 1;
-
-    // Compute the global offset.
-    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
-                       outerMIdx * outerMStride + innerMIdx * innerMStride +
-                       innerKIdx * innerKStride;
-
-    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
-  }
-#endif
-  return nullptr;
-}
-
-// Define a 16 bytes packed data type.
-template <class Type>
-struct PackedVec {
-  typename TypeConverter<Type>::Type elts[4];
-};
-
-template <>
-struct PackedVec<__nv_fp8_e4m3> {
-  __nv_fp8x2_e4m3 elts[8];
-};
-
-template <class Type>
-__inline__ __device__ PackedVec<Type> compute_silu(PackedVec<Type>& vec,
-                                                   PackedVec<Type>& vec2) {
-  PackedVec<Type> result;
-#pragma unroll
-  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) {
-    if constexpr (std::is_same_v<Type, c10::Half>) {
-      half2 val(0.5f, 0.5f);
-      half2 t0 = __hmul2(vec.elts[i], val);
-      half2 t1 = __hfma2(h2tanh(t0), val, val);
-      half2 t2 = __hmul2(vec.elts[i], t1);
-      result.elts[i] = __hmul2(t2, vec2.elts[i]);
-    } else {
-      __nv_bfloat162 val(0.5f, 0.5f);
-      __nv_bfloat162 t0 = __hmul2(vec.elts[i], val);
-      __nv_bfloat162 t1 = __hfma2(h2tanh(t0), val, val);
-      __nv_bfloat162 t2 = __hmul2(vec.elts[i], t1);
-      result.elts[i] = __hmul2(t2, vec2.elts[i]);
-    }
-  }
-  return result;
-}
-
-// Quantizes the provided PackedVec into the uint32_t output
-template <class Type, bool UE8M0_SF = false>
-__device__ uint32_t silu_and_cvt_warp_fp16_to_fp4(PackedVec<Type>& vec,
-                                                  PackedVec<Type>& vec2,
-                                                  float SFScaleVal,
-                                                  uint8_t* SFout) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  PackedVec<Type> out_silu = compute_silu(vec, vec2);
-  // Get absolute maximum values among the local 8 values.
-  auto localMax = __habs2(out_silu.elts[0]);
-
-  // Local maximum value.
-  #pragma unroll
-  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
-    localMax = __hmax2(localMax, __habs2(out_silu.elts[i]));
-  }
-
-  // Get the absolute maximum among all 16 values (two threads).
-  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
-  // Get the final absolute maximum values.
-  float vecMax = float(__hmax(localMax.x, localMax.y));
-
-  // Get the SF (max value of the vector / max value of e2m1).
-  // maximum value of e2m1 = 6.0.
-  // TODO: use half as compute data type.
-  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
-  // 8 bits representation of the SF.
-  uint8_t fp8SFVal;
-  // Write the SF to global memory (STG.8).
-  if constexpr (UE8M0_SF) {
-    // Extract the 8 exponent bits from float32.
-    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
-    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
-    fp8SFVal = tmp & 0xff;
-    // Convert back to fp32.
-    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
-  } else {
-    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
-    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
-    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
-    // Convert back to fp32.
-    SFValue = float(tmp);
-  }
-  // Get the output scale.
-  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
-  //                       reciprocal(SFScaleVal))
-  float outputScale =
-      SFValue != 0 ? reciprocal_approximate_ftz(
-                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
-                   : 0.0f;
-
-  if (SFout) {
-    // Write the SF to global memory (STG.8).
-    *SFout = fp8SFVal;
-  }
-
-  // Convert the input to float.
-  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
-
-  #pragma unroll
-  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
-    if constexpr (std::is_same_v<Type, c10::Half>) {
-      fp2Vals[i] = __half22float2(out_silu.elts[i]);
-    } else {
-      fp2Vals[i] = __bfloat1622float2(out_silu.elts[i]);
-    }
-    fp2Vals[i].x *= outputScale;
-    fp2Vals[i].y *= outputScale;
-  }
-
-  // Convert to e2m1 values.
-  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
-
-  // Write the e2m1 values to global memory.
-  return e2m1Vec;
-#else
-  return 0;
-#endif
-}
-
-// Use UE4M3 by default.
-template <class Type, bool UE8M0_SF = false>
-__global__ void
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-__launch_bounds__(1024, 4) silu_and_cvt_fp16_to_fp4(
-#else
-silu_and_cvt_fp16_to_fp4(
-#endif
-    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
-    uint32_t* out, uint32_t* SFout) {
-#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-  using PackedVec = PackedVec<Type>;
-  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
-      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
-  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
-                "Vec size is not matched.");
-
-  // Get the global scaling factor, which will be applied to the SF.
-  // Note SFScale is the same as next GEMM's alpha, which is
-  // (448.f / (Alpha_A / 6.f)).
-  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
-
-  // Input tensor row/col loops.
-  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
-    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
-         colIdx += blockDim.x) {
-      int64_t inOffset =
-          rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) + colIdx;
-      int64_t inOffset2 = rowIdx * (numCols * 2 / CVT_FP4_ELTS_PER_THREAD) +
-                          numCols / CVT_FP4_ELTS_PER_THREAD + colIdx;
-      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
-      PackedVec in_vec2 = reinterpret_cast<PackedVec const*>(in)[inOffset2];
-
-      // Get the output tensor offset.
-      // Same as inOffset because 8 elements are packed into one uint32_t.
-      int64_t outOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
-      ;
-      auto& out_pos = out[outOffset];
-
-      auto sf_out =
-          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
-                                             CVT_FP4_NUM_THREADS_PER_SF>(
-              rowIdx, colIdx, numCols, SFout);
-
-      out_pos = silu_and_cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(
-          in_vec, in_vec2, SFScaleVal, sf_out);
-    }
-  }
-#endif
-}
-
-}  // namespace vllm
-
-void silu_and_mul_nvfp4_quant(torch::Tensor& output,  // [..., d]
-                              torch::Tensor& output_sf,
-                              torch::Tensor& input,  // [..., 2 * d]
-                              torch::Tensor& input_sf) {
-  TORCH_CHECK(input.dtype() == torch::kFloat16 ||
-              input.dtype() == torch::kBFloat16);
-  int32_t m = input.size(0);
-  int32_t n = input.size(1) / 2;
-  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
-  int multiProcessorCount =
-      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
-  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
-  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
-  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
-  dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
-  int const numBlocksPerSM = 2048 / block.x;
-  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
-  VLLM_DISPATCH_HALF_TYPES(
-      input.scalar_type(), "act_and_mul_quant_kernel", [&] {
-        auto input_ptr = reinterpret_cast<scalar_t const*>(input.data_ptr());
-        VLLM_DISPATCH_BYTE_TYPES(
-            output.scalar_type(), "fused_act_and_mul_quant_kernel_nvfp4_type",
-            [&] {
-              vllm::silu_and_cvt_fp16_to_fp4<scalar_t>
-                  <<<grid, block, 0, stream>>>(
-                      m, n, input_ptr, input_sf_ptr,
-                      reinterpret_cast<uint32_t*>(output_ptr),
-                      reinterpret_cast<uint32_t*>(sf_out));
-            });
-      });
-}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index b769c09adc0f..608b72440307 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -115,13 +115,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
       "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()");
   ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant);

-#ifndef USE_ROCM
-  ops.def(
-      "silu_and_mul_nvfp4_quant(Tensor! result, Tensor! result_block_scale, "
-      "Tensor input, Tensor input_global_scale) -> ()");
-  ops.impl("silu_and_mul_nvfp4_quant", torch::kCUDA, &silu_and_mul_nvfp4_quant);
-#endif
-
   ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()");
   ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu);

diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index fcc2589e4211..0e1059e65447 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -4,41 +4,32 @@
 import torch

 import vllm.envs as envs
-from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
-# yapf conflicts with isort for this block
-# yapf: disable
-from vllm.compilation.activation_quant_fusion import (
-    FUSED_OPS, SILU_MUL_OP, ActivationQuantFusionPass)
-# yapf: enable
-from vllm.compilation.fusion import QUANT_OPS
+from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
+from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.config import CompilationConfig, PassConfig, VllmConfig
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    GroupShape, kFp8StaticTensorSym, kNvfp4Quant)
+    GroupShape)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp)
 from vllm.platforms import current_platform

 from .backend import TestBackend

-FP8_DTYPE = current_platform.fp8_dtype()
-FP4_DTYPE = torch.uint8

+class TestModel(torch.nn.Module):

-def is_nvfp4_supported():
-    return current_platform.has_device_capability(100)
-
-
-class TestSiluMulFp8QuantModel(torch.nn.Module):
-
-    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, **kwargs):
-        super().__init__()
+    def __init__(self, hidden_size: int, force_fp8_e4m3fnuz: bool, *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
         self.silu_and_mul = SiluAndMul()
         self.wscale = torch.rand(1, dtype=torch.float32)
         self.scale = torch.rand(1, dtype=torch.float32)
-        self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
+        self.w = (torch.rand(
+            hidden_size,
+            hidden_size).to(dtype=current_platform.fp8_dtype()).t())

         self.fp8_linear = Fp8LinearOp(
             force_fp8_e4m3fnuz=force_fp8_e4m3fnuz,
@@ -54,56 +45,14 @@ def forward(self, x):
             input_scale=self.wscale)
         return x2

-    def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
-
-    def ops_in_model_after(self):
-        return [FUSED_OPS[kFp8StaticTensorSym]]
-
-
-class TestSiluMulNvfp4QuantModel(torch.nn.Module):
-
-    def __init__(self, hidden_size: int, **kwargs):
-        super().__init__()
-        self.silu_and_mul = SiluAndMul()
-        self.w = torch.randint(256, (hidden_size, hidden_size // 2),
-                               dtype=FP4_DTYPE)
-        self.wscale = torch.randn(hidden_size,
-                                  hidden_size // 16).to(dtype=FP8_DTYPE)
-        self.wscale2 = torch.rand(1, dtype=torch.float32)
-        self.scale = torch.rand(1, dtype=torch.float32)
-
-    def forward(self, x):
-        y = self.silu_and_mul(x)
-        y_quant, y_block_scale = scaled_fp4_quant(y, 1 / self.scale)
-        out = cutlass_scaled_fp4_mm(a=y_quant,
-                                    b=self.w,
-                                    block_scale_a=y_block_scale,
-                                    block_scale_b=self.wscale,
-                                    alpha=self.scale * self.wscale2,
-                                    out_dtype=y.dtype)
-        return out
-
-    def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
-
-    def ops_in_model_after(self):
-        return [FUSED_OPS[kNvfp4Quant]]
-
-
-@pytest.mark.parametrize("num_tokens", [64])
-@pytest.mark.parametrize("hidden_size", [128])
-@pytest.mark.parametrize(
-    "model_class", [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
-    if is_nvfp4_supported() else [TestSiluMulFp8QuantModel])
+@pytest.mark.parametrize("num_tokens", [256])
+@pytest.mark.parametrize("hidden_size", [64])
 @pytest.mark.parametrize("force_fp8_e4m3fnuz", [True, False])
 @pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                     reason="Only test on CUDA and ROCm")
-def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
+def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
                                    force_fp8_e4m3fnuz):
-    if model_class == TestSiluMulNvfp4QuantModel and force_fp8_e4m3fnuz:
-        pytest.skip("Duplicate tests for NVFP4")
-
     torch.set_default_device("cuda")
     torch.set_default_dtype(torch.float16)

@@ -114,8 +63,7 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
     fusion_pass = ActivationQuantFusionPass(config)

     backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
-    model = model_class(hidden_size=hidden_size,
-                        force_fp8_e4m3fnuz=force_fp8_e4m3fnuz)
+    model = TestModel(hidden_size, force_fp8_e4m3fnuz)

     # First dimension dynamic
     x = torch.rand(num_tokens, hidden_size * 2)
@@ -132,8 +80,17 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size, model_class,
                                atol=1e-3,
                                rtol=1e-3)

-    # In pre-nodes, quant op should be present and fused kernels should not
-    backend.check_before_ops(model.ops_in_model_before())
+    # Check substitution worked
+    pre_nodes = backend.graph_pre_pass.nodes
+    post_nodes = backend.graph_post_pass.nodes
+
+    silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default
+    fp8_quant = torch.ops._C.static_scaled_fp8_quant.default
+
+    # In pre-nodes, fp8 quant should be present and fused kernels should not
+    assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None
+    find_auto_fn(pre_nodes, fp8_quant)

-    # In post-nodes, fused kernels should be present and quant op should not
-    backend.check_after_ops(model.ops_in_model_after())
+    # In post-nodes, fused kernels should be present and fp8 quant should not
+    find_auto_fn(post_nodes, silu_and_mul_quant)
+    assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
diff --git a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py b/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
deleted file mode 100644
index 969f14cc3fe6..000000000000
--- a/tests/kernels/quantization/test_silu_nvfp4_quant_fusion.py
+++ /dev/null
@@ -1,126 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import pytest
-import torch
-
-from tests.kernels.utils import opcheck
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.platforms import current_platform
-from vllm.scalar_type import scalar_types
-
-if not current_platform.has_device_capability(100):
-    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
-                allow_module_level=True)
-
-DTYPES = [torch.float16, torch.bfloat16]
-SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
-SEEDS = [42]
-CUDA_DEVICES = ['cuda:0']
-
-FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
-FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
-
-BLOCK_SIZE = 16
-
-
-def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor,
-             global_scale: torch.Tensor,
-             ref_output_scale: torch.Tensor) -> torch.Tensor:
-    silu_and_mul_out = silu_and_mul.forward_native(x)
-    assert not current_platform.is_rocm()
-    assert silu_and_mul_out.ndim >= 1, (
-        f'input.ndim needs to be >= 1, but got {silu_and_mul_out.ndim}.')
-    other_dims = 1 if silu_and_mul_out.ndim == 1 else -1
-    silu_and_mul_out = silu_and_mul_out.reshape(other_dims,
-                                                silu_and_mul_out.shape[-1])
-    m, n = silu_and_mul_out.shape
-    device = silu_and_mul_out.device
-
-    # Two fp4 values will be packed into an uint8.
-    out = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
-
-    output_scale = ref_output_scale
-
-    torch.ops._C.scaled_fp4_quant(out, silu_and_mul_out, output_scale,
-                                  global_scale)
-
-    return out, output_scale
-
-
-def ops_impl(x: torch.Tensor, global_scale: torch.Tensor,
-             ref_output_scale: torch.Tensor) -> torch.Tensor:
-    out_shape = (x.shape[0], x.shape[1] // 4)
-    output_scale = ref_output_scale
-    out = torch.empty(out_shape, dtype=torch.uint8, device=x.device)
-    torch.ops._C.silu_and_mul_nvfp4_quant(out, output_scale, x, global_scale)
-    return out, output_scale
-
-
-@pytest.mark.parametrize("dtype", DTYPES)
-@pytest.mark.parametrize("shape", SHAPES)
-@pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", CUDA_DEVICES)
-@torch.inference_mode()
-def test_quantize_to_fp4(
-    dtype: torch.dtype,
-    shape: tuple[int, int],
-    seed: int,
-    device: str,
-) -> None:
-    current_platform.seed_everything(seed)
-    torch.set_default_device(device)
-
-    m, n = shape
-
-    x = torch.randn((m, n), dtype=dtype)
-    tensor_amax = torch.abs(x).max().to(torch.float32)
-    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
-
-    block_size = 16
-
-    assert n % block_size == 0, (
-        f'last dim has to be multiple of 16, but got {n}.')
-    assert x.dtype in (torch.float16, torch.bfloat16), (
-        f'input.dtype needs to be fp16 or bf16 but got {x.dtype}.')
-
-    round_up = lambda x, y: (x + y - 1) // y * y
-    rounded_m = round_up(x.shape[0], 128)
-    scale_n = x.shape[1] // (2 * block_size)
-    rounded_n = round_up(scale_n, 4)
-    output_scale = torch.empty((rounded_m, rounded_n // 4),
-                               device=x.device,
-                               dtype=torch.int32)
-
-    layer = SiluAndMul()
-
-    ref_out, ref_out_scale = ref_impl(layer, x, global_scale, output_scale)
-
-    fusion_out, fusion_out_scale = ops_impl(x, global_scale, output_scale)
-
-    assert ref_out.dtype == torch.uint8
-    assert fusion_out.dtype == torch.uint8
-    assert ref_out.shape == fusion_out.shape
-
-    assert ref_out_scale.dtype == torch.int32
-    assert fusion_out_scale.dtype == torch.int32
-    assert ref_out_scale.shape == fusion_out_scale.shape
-
-    # Allow up to 2% of mismatched values since BF16 has accuracy issues.
-    mis_threshold = 0.02
-    atol = 0.4
-    rtol = 0.4
-    ref_logits = ref_out[-1]
-    fusion_logits = fusion_out[-1]
-
-    mis_count = torch.sum(
-        torch.abs(fusion_logits - ref_logits) > (atol +
-                                                 rtol * torch.abs(ref_logits)))
-    mis_ratio = mis_count / fusion_logits.numel()
-
-    assert mis_ratio < mis_threshold, \
-        f"Mismatch ratio {mis_ratio} exceeds threshold {mis_threshold}"
-
-    torch.testing.assert_close(ref_out_scale, fusion_out_scale)
-
-    opcheck(torch.ops._C.silu_and_mul_nvfp4_quant,
-            (fusion_out, fusion_out_scale, x, global_scale))
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index 40e124a03eb0..826014f770df 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -1,154 +1,55 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from abc import ABC, abstractmethod
-
 import torch
 from torch._higher_order_ops.auto_functionalize import auto_functionalized
 from torch._inductor.pattern_matcher import (PatternMatcherPass, fwd_only,
                                              register_replacement)
-from torch._ops import OpOverload

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    QuantKey, kFp8StaticTensorSym, kNvfp4Quant, kStaticTensorScale)
 from vllm.platforms import current_platform

-from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
 from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass

 logger = init_logger(__name__)

-FP8_DTYPE = current_platform.fp8_dtype()
-FP4_DTYPE = torch.uint8
-
-SILU_MUL_OP = torch.ops._C.silu_and_mul.default
-
-FUSED_OPS: dict[QuantKey, OpOverload] = {
-    kFp8StaticTensorSym: torch.ops._C.silu_and_mul_quant.default,  # noqa: E501
-}
-if current_platform.is_cuda() and hasattr(torch.ops._C,
-                                          "silu_and_mul_nvfp4_quant"):
-    FUSED_OPS[
-        kNvfp4Quant] = torch.ops._C.silu_and_mul_nvfp4_quant.default  # noqa: E501

+def silu_mul_pattern_static(result: torch.Tensor,
+                            result_silu_mul: torch.Tensor, input: torch.Tensor,
+                            scale: torch.Tensor):
+    at1 = auto_functionalized(torch.ops._C.silu_and_mul.default,
+                              result=result_silu_mul,
+                              input=input)
+    at2 = auto_functionalized(torch.ops._C.static_scaled_fp8_quant.default,
+                              result=result,
+                              input=at1[1],
+                              scale=scale)
+    return at2[1]

-class ActivationQuantPattern(ABC):
-    """
-    The base class for Activation+Quant fusions.
-    Should not be used directly.
- """ +def silu_mul_replacement_static(result: torch.Tensor, + result_silu_mul: torch.Tensor, + input: torch.Tensor, scale: torch.Tensor): + at = auto_functionalized(torch.ops._C.silu_and_mul_quant.default, + result=result, + input=input, + scale=scale) + return at[1] - def __init__( - self, - quant_key: QuantKey, - ): - self.quant_key = quant_key - self.quant_dtype = quant_key.dtype - assert self.quant_key in QUANT_OPS, \ - f"unsupported quantization scheme {self.quant_key}" - self.QUANT_OP = QUANT_OPS[self.quant_key] +def empty_bf16(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.bfloat16, device="cuda") - assert self.quant_key in FUSED_OPS, \ - f"unsupported fusion scheme {self.quant_key}" - self.FUSED_OP = FUSED_OPS[self.quant_key] - def empty_quant(self, *args, **kwargs): - kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} - return torch.empty(*args, **kwargs) +def empty_fp8(*args, **kwargs): + fp8 = current_platform.fp8_dtype() + return torch.empty(*args, **kwargs, dtype=fp8, device="cuda") - @abstractmethod - def register(self, pm_pass: PatternMatcherPass): - raise NotImplementedError - -class SiluMulFp8StaticQuantPattern(ActivationQuantPattern): - """ - Fusion for SiluMul+Fp8StaticQuant Pattern - """ - - def __init__(self, symmetric: bool = True): - quant_key = QuantKey(dtype=FP8_DTYPE, - scale=kStaticTensorScale, - symmetric=symmetric) - super().__init__(quant_key) - - def register(self, pm_pass: PatternMatcherPass): - - def pattern(result: torch.Tensor, result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at1 = auto_functionalized(SILU_MUL_OP, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(self.QUANT_OP, - result=result, - input=at1[1], - scale=scale) - return at2[1] - - def replacement(result: torch.Tensor, result_silu_mul: torch.Tensor, - input: torch.Tensor, scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - input=input, - scale=scale) - return at[1] - - inputs = [ - self.empty_quant(5, 4), # result - empty_bf16(5, 4), # result_silu_mul - empty_bf16(5, 4), # input - empty_fp32(1, 1) # scale - ] - - register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) - - -class SiluMulNvfp4QuantPattern(ActivationQuantPattern): - """ - Fusion for SiluMul+Nvfp4Quant Pattern - """ - - def __init__(self): - super().__init__(kNvfp4Quant) - - def register(self, pm_pass: PatternMatcherPass): - - def pattern(result: torch.Tensor, output_scale: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, - scale: torch.Tensor): - at1 = auto_functionalized(SILU_MUL_OP, - result=result_silu_mul, - input=input) - at2 = auto_functionalized(self.QUANT_OP, - output=result, - input=at1[1], - output_scale=output_scale, - input_scale=scale) - return at2[1], at2[2] - - def replacement(result: torch.Tensor, output_scale: torch.Tensor, - result_silu_mul: torch.Tensor, input: torch.Tensor, - scale: torch.Tensor): - at = auto_functionalized(self.FUSED_OP, - result=result, - result_block_scale=output_scale, - input=input, - input_global_scale=scale) - return at[1], at[2] - - inputs = [ - self.empty_quant(5, 32), # result - empty_i32(128, 4), # output_scale - empty_bf16(5, 64), # result_silu_mul - empty_bf16(5, 64), # input - empty_fp32(1, 1) # scale - ] - - register_replacement(pattern, replacement, inputs, fwd_only, pm_pass) +def empty_fp32(*args, **kwargs): + return torch.empty(*args, **kwargs, dtype=torch.float32, device="cuda") class ActivationQuantFusionPass(VllmInductorPass): 
@@ -168,11 +69,15 @@ def __init__(self, config: VllmConfig):
         self.patterns: PatternMatcherPass = PatternMatcherPass(
             pass_name="activation_quant_fusion_pass")

-        pattern_silu_mul_fp8 = SiluMulFp8StaticQuantPattern()
-        pattern_silu_mul_fp8.register(self.patterns)
-
-        pattern_silu_mul_nvfp4 = SiluMulNvfp4QuantPattern()
-        pattern_silu_mul_nvfp4.register(self.patterns)
+        inputs = [
+            empty_fp8(5, 4),  # Quant output
+            empty_bf16(5, 4),  # Silu_and_mul output
+            empty_bf16(5, 4),  # Input
+            empty_fp32(1, 1)  # Scale
+        ]
+        register_replacement(silu_mul_pattern_static,
+                             silu_mul_replacement_static, inputs, fwd_only,
+                             self.patterns)

     def __call__(self, graph: torch.fx.Graph):
         self.begin()
@@ -184,8 +89,3 @@ def __call__(self, graph: torch.fx.Graph):

         self.dump_graph(graph, "after_act_quant_fusion")
         self.end_and_log()
-
-    def uuid(self):
-        return VllmInductorPass.hash_source(self, ActivationQuantPattern,
-                                            SiluMulFp8StaticQuantPattern,
-                                            SiluMulNvfp4QuantPattern)
diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py
index a36dd8b845f1..60ae14331879 100644
--- a/vllm/compilation/fix_functionalization.py
+++ b/vllm/compilation/fix_functionalization.py
@@ -97,13 +97,6 @@ def __call__(self, graph: torch.fx.Graph):
                                        node,
                                        mutated_args,
                                        args=('result', 'input', 'scale'))
-            elif at_target == torch.ops._C.silu_and_mul_nvfp4_quant.default:
-                mutated_args = {1: 'result', 2: 'result_block_scale'}
-                self.defunctionalize(graph,
-                                     node,
-                                     mutated_args,
-                                     args=('result', 'result_block_scale',
-                                           'input', 'input_global_scale'))
             else:
                 continue  # skip the count
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 1fbb2e3bb6f2..9d4e453ffc54 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -885,10 +885,6 @@ def process_weights_after_loading(self, layer: Module) -> None:
         layer.alpha = Parameter(layer.input_scale * layer.weight_scale_2,
                                 requires_grad=False)

-        # Calculate `1 / input_scale` so that we don't need to do so at runtime
-        layer.input_scale_inv = Parameter(
-            (1 / layer.input_scale).to(torch.float32), requires_grad=False)
-
         # Swizzle the weight blockscale.
         # contracting dimension is input dimension
         # block_size = 16;
@@ -945,7 +941,8 @@ def apply(
         output_shape = [x.shape[0], layer.weight.shape[0]]

         # quantize BF16 or FP16 to (FP4 and interleaved block scale)
-        x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv)
+        s_quant = 1 / layer.input_scale
+        x_fp4, x_blockscale = scaled_fp4_quant(x, s_quant)

         # validate dtypes of quantized input, input block scale,
         # weight and weight_blockscale